Added attention consensus model #157

Open · wants to merge 6 commits into dev-v0.1.0
26 changes: 26 additions & 0 deletions tests/test_nn_layers.py
@@ -0,0 +1,26 @@
#
# Copyright 2020 NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch

from variantworks.layers.attention import Attention


def test_attention_layer():
    # All-zeros input with bias-free linear layers yields tanh(0) = 0,
    # so the attention output should equal the (zero) input.
    input_tensor = torch.zeros((10, 10, 5), dtype=torch.float32)
    attn_layer = Attention(5)
    out, _ = attn_layer(input_tensor, input_tensor)
    assert torch.all(input_tensor.eq(out))
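Review note: because the input is all zeros and the linear layers carry no bias, the equality above holds trivially. A complementary test on random input could also pin down the output shapes and the softmax normalization; a sketch follows, with the test name and sizes being illustrative rather than part of this PR:

import torch

from variantworks.layers.attention import Attention


def test_attention_layer_shapes():
    # Illustrative sizes; any batch/sequence/feature sizes should behave the same.
    batch_size, seq_len, dims = 4, 7, 5
    query = torch.rand((batch_size, seq_len, dims), dtype=torch.float32)
    context = torch.rand((batch_size, seq_len, dims), dtype=torch.float32)
    attn_layer = Attention(dims)
    out, weights = attn_layer(query, context)
    assert out.shape == (batch_size, seq_len, dims)
    assert weights.shape == (batch_size, seq_len, seq_len)
    # Softmax is applied over the last axis, so each row of weights sums to 1.
    assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, seq_len))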
126 changes: 126 additions & 0 deletions variantworks/layers/attention.py
@@ -0,0 +1,126 @@
#
# Copyright 2020 NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# The implementation in this file is adapted from a third-party repository distributed under the BSD 3-Clause License.
# BSD 3-Clause License
#
# Copyright (c) James Bradbury and Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Attention related layers."""

import torch
import torch.nn as nn


class Attention(nn.Module):
"""Applies attention mechanism on the `context` using the `query`.

Implementation from: https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/attention.html
"""

def __init__(self, dimensions, attention_type='general'):
"""Construct an Attention layer.

Args:
dimensions (int): Dimensionality of the query and context.
attention_type (str, optional): How to compute the attention score:

* dot: :math:`score(H_j,q) = H_j^T q`
* general: :math:`score(H_j, q) = H_j^T W_a q`
"""
super(Attention, self).__init__()

if attention_type not in ['dot', 'general']:
raise ValueError('Invalid attention type selected.')

self.attention_type = attention_type
if self.attention_type == 'general':
self.linear_in = nn.Linear(dimensions, dimensions, bias=False)

self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
self.softmax = nn.Softmax(dim=-1)
self.tanh = nn.Tanh()

def forward(self, query, context):
"""Forward method.

Args:
            query : Sequence of queries to query the
                context [batch size, output length, dimensions].
            context : Data over which to apply the attention
                mechanism [batch size, query length, dimensions].

Returns:
Tuple with output and weights:
* output : Tensor containing the attended features [batch size, output length, dimensions].
* weights : Tensor containing attention weights [batch size, output length, query length].
"""
batch_size, output_len, dimensions = query.size()
query_len = context.size(1)

if self.attention_type == "general":
query = query.reshape(batch_size * output_len, dimensions)
query = self.linear_in(query)
query = query.reshape(batch_size, output_len, dimensions)

        # (batch_size, output_len, dimensions) x (batch_size, dimensions, query_len) ->
# (batch_size, output_len, query_len)
attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())

# Compute weights across every context sequence
attention_scores = attention_scores.view(batch_size * output_len, query_len)
attention_weights = self.softmax(attention_scores)
attention_weights = attention_weights.view(batch_size, output_len, query_len)

# (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
# (batch_size, output_len, dimensions)
mix = torch.bmm(attention_weights, context)

# concat -> (batch_size * output_len, 2*dimensions)
combined = torch.cat((mix, query), dim=2)
combined = combined.view(batch_size * output_len, 2 * dimensions)

        # Apply linear_out along the concatenated feature dimension
        # output -> (batch_size, output_len, dimensions)
output = self.linear_out(combined).view(batch_size, output_len, dimensions)
output = self.tanh(output)

return output, attention_weights
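For review context, a minimal usage sketch of this layer; the tensor sizes below are illustrative, not taken from the PR:

import torch

from variantworks.layers.attention import Attention

# Batch of 2, output length 3, context length 4, 8 feature dimensions.
query = torch.rand(2, 3, 8)
context = torch.rand(2, 4, 8)

attn = Attention(8, attention_type='general')
output, weights = attn(query, context)
print(output.shape)   # torch.Size([2, 3, 8])
print(weights.shape)  # torch.Size([2, 3, 4])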
67 changes: 67 additions & 0 deletions variantworks/networks.py
@@ -24,6 +24,8 @@
from nemo.core.neural_types import NeuralType, ChannelType, LogitsType
from nemo.core.neural_factory import DeviceType

from variantworks.layers.attention import Attention


class AlexNet(TrainableNM):
"""A Neural Module for AlexNet."""
@@ -254,3 +256,68 @@ def forward(self, encoding):
encoding = self.classifier(encoding)
encoding = F.softmax(encoding, dim=2)
return encoding


class ConsensusAttention(TrainableNM):
"""A Neural Module for training a Consensus Attention Model."""

@property
@add_port_docs()
def input_ports(self):
"""Return definitions of module input ports.

Returns:
Module input ports.
"""
return {
"encoding": NeuralType(('B', 'W', 'C'), ChannelType()),
}

@property
@add_port_docs()
def output_ports(self):
"""Return definitions of module output ports.

Returns:
Module output ports.
"""
return {
# Variant type
'output_logit': NeuralType(('B', 'W', 'D'), LogitsType()),
}

def __init__(self, sequence_length, input_feature_size, num_output_logits):
"""Construct an Consensus RNN NeMo instance.

Args:
sequence_length : Length of sequence to feed into RNN.
input_feature_size : Length of input feature set.
num_output_logits : Number of output classes of classifier.

Returns:
Instance of class.
"""
super().__init__()
self.num_output_logits = num_output_logits

self.attn = Attention(input_feature_size)
self.gru = nn.GRU(input_feature_size, 16, 1, batch_first=True, bidirectional=True)
self.classifier = nn.Linear(32, self.num_output_logits)

self._device = torch.device(
"cuda" if self.placement == DeviceType.GPU else "cpu")
self.to(self._device)

def forward(self, encoding):
"""Abstract function to run the network.

Args:
encoding : Input sequence to run network on.

Returns:
Output of forward pass.
"""
        # Self-attention: the encoding attends over itself (query == context).
        encoding, _ = self.attn(encoding, encoding)
        # Bidirectional GRU (hidden size 16) -> (batch, seq, 32).
        encoding, _ = self.gru(encoding)
        # Per-position logits -> (batch, seq, num_output_logits).
        encoding = self.classifier(encoding)
return encoding
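To see the new module end to end, a hedged sketch follows. It assumes NeMo 0.x's NeuralModuleFactory API (the constructor reads its device placement from the active factory) and calls forward() directly to bypass NeMo's graph wiring; all sizes are illustrative:

import torch
import nemo
from nemo.core.neural_factory import DeviceType

from variantworks.networks import ConsensusAttention

# Assumption: a factory must exist so the constructor can read self.placement.
nemo.core.NeuralModuleFactory(placement=DeviceType.CPU)

# Batch of 2, a 10-position window, 5 input features, 3 output classes.
model = ConsensusAttention(sequence_length=10, input_feature_size=5, num_output_logits=3)
encoding = torch.rand(2, 10, 5)
logits = model.forward(encoding)
print(logits.shape)  # torch.Size([2, 10, 3])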