Stative Attenders #569
base: master
Conversation
I had a chance to take a look. Besides the minor comments on documentation, it basically looks fine, although I'm wondering if there would be a way to make the changes more modular and local to the attender, i.e. remove the need to call update() from outside the attender. I'm not sure if that would work well with beam search etc., though?
-- Matthias
(this comment is "not a contribution")
```diff
 class Attender(object):
   """
   A template class for functions implementing attention.
   """

-  def init_sent(self, sent: expression_seqs.ExpressionSequence) -> None:
+  def init_sent(self, sent: expression_seqs.ExpressionSequence) -> AttenderState:
     """Args:
       sent: the encoder states, aka keys and values. Usually but not necessarily an :class:`expression_seqs.ExpressionSequence`
     """
```
Return value needs documentation.
xnmt/modelparts/attenders.py (outdated)
```python
                                 hidden_dim=self.coverage_dim,
                                 param_init=param_init,
                                 bias_init=bias_init))

  def init_sent(self, sent: expression_seqs.ExpressionSequence) -> None:
```
Return type outdated.
```python
    I = self.curr_sent.as_tensor()
    return I * attention

  def update(self, dec_state: decoders.DecoderState, att_state: AttenderState, attention: dy.Expression):
    return None
```
Could you document how update is intended to be used?
Thanks for the feedback, Matthias! I updated the documentation. I talked to Graham about the mechanism of compute_attention() and update(). I agree that it would be nice to factor this so the translator class doesn't have to call update(), but I'm not sure that's possible in general. The real problem is that the attention vector returned by the attender may not be the "final" attention vector used downstream. For example, if one chooses to ensemble multiple attenders, then the final attention vector will not be the same as the vector produced by any individual one. This mechanism allows us to feed the real attention vector back into the attender even in these cases. I talked to Graham a bit about this, and this was the best solution we came up with. I'm happy to discuss further if you see a better way!
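To illustrate the point above, here is a schematic sketch in plain Python (not xnmt's real classes; the names `StatefulAttender`, `calc_attention`, and the integer state are all hypothetical stand-ins) of why the caller, rather than the attender itself, has to invoke update(): when several attenders are ensembled, the final attention vector exists only at the translator level.

```python
class StatefulAttender:
    """Toy stative attender. The state here is just a step counter;
    a real attender (e.g. a coverage model) would carry richer state."""

    def init_sent(self, keys):
        # Store the encoder states and return the initial AttenderState.
        self.keys = keys
        return 0

    def calc_attention(self, query, state):
        # Toy uniform attention over the source positions (ignores query).
        n = len(self.keys)
        return [1.0 / n] * n

    def update(self, att_state, final_attention):
        # Receives the *final* (possibly ensembled) attention vector,
        # which may differ from what calc_attention returned.
        return att_state + 1


def ensemble_step(attenders, states, query):
    # Each attender proposes an attention vector...
    atts = [a.calc_attention(query, s) for a, s in zip(attenders, states)]
    # ...the translator averages them into the final attention...
    final = [sum(col) / len(atts) for col in zip(*atts)]
    # ...and only then can each attender's state be advanced with it.
    new_states = [a.update(s, final) for a, s in zip(attenders, states)]
    return final, new_states
```

Since `final` is computed outside any single attender, no attender could have updated its own state correctly on its own; the explicit update() call closes that loop.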
I see, yeah I had suspected something like that. In that case I think this can be merged (once the merge conflicts are resolved)! -- Matthias
@armatthews If you resolve the conflicts on this I think we can merge.
This PR enables stative attenders, and contains a sample implementation of "Modeling Coverage for Neural Machine Translation" (Tu et al. 2016).
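For readers unfamiliar with the referenced paper, here is a minimal sketch of the coverage idea from Tu et al. (2016) in plain Python: the attender state is a per-source-position coverage vector (cumulative attention so far), which feeds back into the scoring of each position on the next step. The simple linear penalty below is an assumption for illustration; the paper learns the coverage update (e.g. with a GRU), and the function names are hypothetical, not xnmt's API.

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def attend_with_coverage(base_scores, coverage, penalty=1.0):
    # Penalize source positions that have already received attention.
    scores = [s - penalty * c for s, c in zip(base_scores, coverage)]
    att = softmax(scores)
    # The new AttenderState: coverage accumulates the attention weights.
    new_coverage = [c + a for c, a in zip(coverage, att)]
    return att, new_coverage
```

Run twice on the same base scores and the second step shifts attention mass away from the position attended on the first step, which is the behavior the coverage model is designed to encourage.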