how to write the concat attention #3

Open
zsgchinese opened this issue Mar 14, 2018 · 5 comments

Comments

@zsgchinese

I see that you used dot attention with self-attention. I don't know how to write concat attention, where the score is v * tanh(W[ht; hs]) rather than ht * hs, because I am a beginner in TensorFlow. Thanks!

@26hzhang
Owner

@zsgchinese Can you point me to a related paper on concat attention? I am not sure which one you are talking about.

@zsgchinese
Author

zsgchinese commented Mar 25, 2018

@IsaacChanghau http://www.emnlp2015.org/proceedings/EMNLP/pdf/EMNLP166.pdf
See Section 3.1, Global Attention. To calculate the score you used dot attention; I don't know how to write concat attention.
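
For reference, the three content-based score functions given in Section 3.1 of that paper can be written as follows, where $h_t$ is the current target hidden state and $\bar{h}_s$ is a source hidden state. The alignment weights are then a softmax of the scores over the source positions.

```latex
\mathrm{score}(h_t, \bar{h}_s) =
\begin{cases}
h_t^\top \bar{h}_s & \text{(dot)} \\
h_t^\top W_a \bar{h}_s & \text{(general)} \\
v_a^\top \tanh\left(W_a [h_t; \bar{h}_s]\right) & \text{(concat)}
\end{cases}
\qquad
a_t(s) = \frac{\exp\left(\mathrm{score}(h_t, \bar{h}_s)\right)}{\sum_{s'} \exp\left(\mathrm{score}(h_t, \bar{h}_{s'})\right)}
```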

@26hzhang
Owner

26hzhang commented Apr 1, 2018

@zsgchinese Hi, I was on a long vacation and have just come back. I will read the paper and try to write it for your reference.

@26hzhang
Owner

26hzhang commented Apr 1, 2018

@zsgchinese Here are my thoughts on your request.

The dot attention I used is inspired by “Attention Is All You Need” (ref. https://arxiv.org/pdf/1706.03762.pdf). I think it is different from the dot method described in the paper you shared, and I also think the seq2seq model in that paper is not well suited to the sequence labeling task here: machine translation takes two different inputs during training, source-language sentences (for encoding) and target-language sentences (for decoding), while sequence labeling accepts only a single input.
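
For comparison, the two "dot" formulations as defined in the respective papers are sketched below. How closely the repository's implementation follows the first one is not shown in this thread, so treat this only as a reminder of the definitions. In self-attention, $Q$, $K$ and $V$ are all derived from the same sequence, whereas the dot score of the translation paper compares a decoder state $h_t$ with an encoder state $\bar{h}_s$.

```latex
% Scaled dot-product attention ("Attention Is All You Need")
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

% Dot score ("Effective Approaches to Attention-based Neural Machine Translation")
\mathrm{score}(h_t, \bar{h}_s) = h_t^\top \bar{h}_s
```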

But, considering only your request, I think there are two ways:

  1. If you want to write code similar to the paper you shared, “Effective Approaches to Attention-based Neural Machine Translation”, which is an encoder-decoder (seq2seq) model: TensorFlow provides a comprehensive wrapper for it, and you can follow its tutorial: https://www.tensorflow.org/tutorials/seq2seq

  2. If not, I assume you want to use a mechanism similar to the one described in “Effective Approaches to Attention-based Neural Machine Translation” to compute each hidden output of the dynamic RNN. In this case, at each time step you need to compute the alignment weights over the source hidden states, so I think we can create an attention cell in TensorFlow to tackle this. (I am not sure this idea fits your requirement, but I tested that it works.) See below:

The attention cell is built as:

import tensorflow as tf
from tensorflow.python.ops.rnn_cell import LSTMCell, GRUCell, RNNCell
from model.nns import dense


class AttentionCell(RNNCell):
    """A time-major Attention based RNN cell
    ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py"""
    def __init__(self, num_units, memory, cell_type='lstm'):
        """
        :param num_units: number of hidden units in attention cell
        :param memory: all the source hidden state, shape = (max_time, batch_size, dim)
        :param cell_type: rnn cell type
        """
        super(AttentionCell, self).__init__()
        self._cell = LSTMCell(num_units) if cell_type == 'lstm' else GRUCell(num_units)
        self.num_units = num_units
        self.memory = memory
        self.mem_units = memory.get_shape().as_list()[-1]

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

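    def __call__(self, inputs, state, scope=None):
        # LSTM state is a (c, h) tuple; GRU state is a single tensor that is itself the hidden state
        m = state[1] if isinstance(state, tuple) else state  # m is the previous hidden state
        # tanh part of the score: tanh(memory + W * m), shape = (max_time, batch_size, mem_units)
        concat1 = tf.nn.tanh(tf.add(self.memory, dense(m, self.mem_units, use_bias=False, scope='concat')))
        # project to one unit per position and exponentiate: unnormalized weights, shape = (max_time, batch_size)
        alphas = tf.squeeze(tf.exp(dense(concat1, hidden=1, use_bias=False, scope='raw_alphas')), axis=[-1])
        # normalize over the time axis so the weights sum to 1 for each batch entry
        alphas = tf.div(alphas, tf.reduce_sum(alphas, axis=0, keep_dims=True))  # (max_time, batch_size)
        # context vector: weighted sum of the memory over time, shape = (batch_size, mem_units)
        w_context = tf.reduce_sum(tf.multiply(self.memory, tf.expand_dims(alphas, axis=-1)), axis=0)
        h, new_state = self._cell(inputs, state)
        # combine the context vector with the cell output: tanh(W * [context; h])
        concat2 = tf.concat([w_context, h], axis=-1)
        output = tf.nn.tanh(dense(concat2, self.num_units, use_bias=False, scope='dense'))
        return output, new_state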
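
To make the shapes easier to follow, here is a minimal NumPy sketch of one attention step in the same time-major layout. It is only an illustration under assumptions: `concat_attention_step`, `w_proj` and `v` are hypothetical names, random arrays stand in for the weights of the two `dense` layers, and subtracting the maximum before `exp` is an extra numerical-stability step that the cell above does not perform.

```python
import numpy as np


def concat_attention_step(memory, m, w_proj, v):
    """One attention step in time-major layout (illustrative only).

    memory: (max_time, batch_size, dim) source hidden states
    m:      (batch_size, num_units) previous hidden state
    w_proj: (num_units, dim) projection applied to m (hypothetical name)
    v:      (dim,) scoring vector (hypothetical name)
    """
    # tanh(memory + m W): (batch_size, dim) broadcasts over the time axis
    hidden = np.tanh(memory + m @ w_proj)           # (max_time, batch_size, dim)
    scores = hidden @ v                              # (max_time, batch_size)
    # softmax over the time axis; the max is subtracted for stability
    scores = scores - scores.max(axis=0, keepdims=True)
    alphas = np.exp(scores)
    alphas = alphas / alphas.sum(axis=0, keepdims=True)
    # weighted sum of the memory over time
    context = (memory * alphas[..., None]).sum(axis=0)  # (batch_size, dim)
    return alphas, context


rng = np.random.default_rng(0)
max_time, batch_size, dim, num_units = 5, 2, 4, 3
memory = rng.normal(size=(max_time, batch_size, dim))
m = rng.normal(size=(batch_size, num_units))
w_proj = rng.normal(size=(num_units, dim))
v = rng.normal(size=(dim,))

alphas, context = concat_attention_step(memory, m, w_proj, v)
print(alphas.shape, context.shape)  # prints (5, 2) (2, 4)
```

The weights in `alphas` sum to 1 along axis 0 for every batch entry, which is what the `tf.div` line in the cell is meant to achieve.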

To use the attention cell, assume you have an output from an RNN layer named “context”.
The shape of “context” is (batch_size, max_time, num_units). Since the attention cell I built is time-major, you need to transpose “context” first.

# ……
context = tf.transpose(context, [1, 0, 2])  # (max_time, batch_size, num_units)
att_cell = AttentionCell(num_units, context, cell_type='lstm')  # create attention cell
# use dynamic rnn to compute the output
att_output, _ = tf.nn.dynamic_rnn(att_cell, context, sequence_length=self.seq_len, dtype=tf.float32, time_major=True)
# transpose att_output back to batch-major
att_output = tf.transpose(att_output, [1, 0, 2])  # shape = (batch_size, max_time, num_units)
# ……

Then you get the attentive RNN outputs, which you can use for further processing.

Thanks.

Also, I do not have a GPU at hand right now, so I only tested that the code compiles; I have not trained the model properly.

@zsgchinese
Author

@IsaacChanghau First of all, thank you very much for your detailed reply. I have read it carefully, and here is my understanding of your answer.
concat1 = tf.nn.tanh(tf.add(self.memory, dense(m, self.mem_units, use_bias=False, scope='concat')))
The shape of memory (context) is (batch_size, max_time, num_units), and the shape of m is (batch_size, num_units).
The function call() should be called at every step of dynamic_rnn (that is, the decoder).
But is concat1 the score (Section 3.1, Global Attention, in the paper I referred to)? If it is, I cannot tell which of the three functions (dot, general, concat) you used.
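
One possible reading of the code, offered as an interpretation that the owner has not confirmed in this thread: `concat1` is only the tanh(...) part, and the score is produced by the following `dense(concat1, hidden=1, ...)` projection, which plays the role of v. Because a weight applied to a concatenation splits column-wise, W[ht; hs] = W1 * ht + W2 * hs, so `tanh(memory + dense(m))` resembles the concat score with the memory left unprojected (W2 fixed to the identity) and with the previous hidden state m in place of the current ht. The NumPy sketch below checks the split identity on random data; all names in it are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4
h_t = rng.normal(size=(dim,))            # target (decoder) hidden state
h_s = rng.normal(size=(dim,))            # one source hidden state
W_a = rng.normal(size=(dim, 2 * dim))    # weight on the concatenation [h_t; h_s]
v_a = rng.normal(size=(dim,))            # scoring vector

# concat score: v_a^T tanh(W_a [h_t; h_s])
concat_score = v_a @ np.tanh(W_a @ np.concatenate([h_t, h_s]))

# the same score with W_a split into the columns acting on h_t and on h_s
W_1, W_2 = W_a[:, :dim], W_a[:, dim:]
split_score = v_a @ np.tanh(W_1 @ h_t + W_2 @ h_s)

# form resembling the cell: the source state is added without projection
cell_style_score = v_a @ np.tanh(W_1 @ h_t + h_s)

# the other two score functions, for comparison
dot_score = h_t @ h_s
W_g = rng.normal(size=(dim, dim))
general_score = h_t @ W_g @ h_s

print(np.isclose(concat_score, split_score))  # prints True
```

Under this reading the cell is closest to concat, not dot or general, although it is not identical to the paper's formulation.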
