Commit

fix prev compatibility lopuhin#2
gsuszka committed Aug 9, 2019
1 parent 3aa162c · commit d5944b1
Showing 1 changed file with 17 additions and 30 deletions.
lm/model.py (47 changes: 17 additions & 30 deletions)
@@ -6,9 +6,11 @@
 import attr
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+import torch.utils.checkpoint



 @attr.s(auto_attribs=True, frozen=True)
 class HParams:
@@ -18,7 +20,7 @@ class HParams:
     n_hidden: int
     n_head: int
     n_layer: int
-    gradient_checkpointing: bool
+    gradient_checkpointing: bool = False


 class Model(nn.Module):
@@ -31,8 +33,7 @@ def __init__(self, hparams: HParams, text_gen_mode: bool = False):
         self.wte = nn.Embedding(hparams.n_vocab, hparams.n_embed)
         nn.init.normal_(self.wte.weight, std=0.02)
         self.blocks = nn.ModuleList(
-            [Block(hparams) for _ in range(hparams.n_layer)]
-        )
+            [Block(hparams) for _ in range(hparams.n_layer)])
         self.ln_f = Norm(self.hparams.n_hidden)
         if hparams.n_hidden != hparams.n_embed:
             self.in_proj = Conv1D(hparams.n_embed, hparams.n_hidden)
@@ -44,9 +45,7 @@ def forward(self, x, past=None):
         # Embedding
         past_length = 0 if past is None else past.shape[-2]
         batch_size, n_ctx = x.shape
-        position = position_for(
-            batch_size, n_ctx, past_length, x.device
-        )
+        position = position_for(batch_size, n_ctx, past_length, x.device)
         h = self.wte(x) + self.wpe(position)
         assert h.shape == (batch_size, n_ctx, self.hparams.n_embed)
         if self.in_proj:
@@ -56,9 +55,11 @@ def forward(self, x, past=None):
         # presents = []
         for i, block in enumerate(self.blocks):
             if getattr(self.hparams, 'gradient_checkpointing', False):
-                h, present = torch.utils.checkpoint.checkpoint(block, h, past[:, i] if past is not None else None)
+                h, present = torch.utils.checkpoint.checkpoint(
+                    block, h, past[:, i] if past is not None else None)
             else:
-                h, present = block(h, past=past[:, i] if past is not None else None)
+                h, present = block(
+                    h, past=past[:, i] if past is not None else None)
             # presents.append(present)
         h = self.ln_f(h)

@@ -68,13 +69,9 @@ def forward(self, x, past=None):
         output = {"hidden": h}

         if self._text_gen_mode:
-            h_flat = h.reshape(
-                [batch_size * n_ctx, self.hparams.n_embed]
-            )
+            h_flat = h.reshape([batch_size * n_ctx, self.hparams.n_embed])
             logits = torch.matmul(h_flat, self.wte.weight.t())
-            logits = logits.reshape(
-                [batch_size, n_ctx, self.hparams.n_vocab]
-            )
+            logits = logits.reshape([batch_size, n_ctx, self.hparams.n_vocab])
             output["logits"] = logits

         return output
@@ -99,7 +96,6 @@ def forward(self, x, past):
 class Norm(nn.Module):
     """ Normalize to mean = 0, std = 1, then do a diagonal affine transform.
     """
-
     def __init__(self, n_features, *, dim=-1, epsilon=1e-5):
         super().__init__()
         self.n_features = n_features
@@ -145,9 +141,7 @@ def forward(self, x, past):
         assert past.shape[-1] == self.hparams.n_hidden

         c = self.c_attn(x)
-        q, k, v = map(
-            self.split_heads, torch.split(c, x.shape[-1], dim=2)
-        )
+        q, k, v = map(self.split_heads, torch.split(c, x.shape[-1], dim=2))

         present = torch.stack([k, v], dim=1)

@@ -166,9 +160,7 @@ def split_heads(self, x):
         """ From [batch, sequence, features] to
         [batch, heads, sequence, features].
         """
-        return self.split_states(x, self.hparams.n_head).permute(
-            0, 2, 1, 3
-        )
+        return self.split_states(x, self.hparams.n_head).permute(0, 2, 1, 3)

     @staticmethod
     def split_states(x, n):
@@ -225,14 +217,9 @@ def reset_parameters(self):


 def gelu(x, c=math.sqrt(2 / math.pi)):
-    return (
-        0.5 * x * (1 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3))))
-    )
+    return (0.5 * x * (1 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3)))))


 def position_for(batch_size, n_steps, past_length, device=None):
-    return (
-        torch.arange(past_length, n_steps + past_length, device=device)
-        .unsqueeze(0)
-        .repeat(batch_size, 1)
-    )
+    return (torch.arange(past_length, n_steps + past_length,
+                         device=device).unsqueeze(0).repeat(batch_size, 1))
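
A note on the compatibility fix itself: because gradient_checkpointing now defaults to False, hyper-parameter sets saved before the field existed can still be used to construct HParams, and the getattr(self.hparams, 'gradient_checkpointing', False) guard in forward also covers hparams objects pickled without the attribute at all. A minimal sketch of that pattern, using a hypothetical ParamsSketch class reduced to the fields visible in this diff (the real HParams has additional fields collapsed out of the hunk above):

import attr


@attr.s(auto_attribs=True, frozen=True)
class ParamsSketch:
    # Hypothetical stand-in for HParams, not the class from lm/model.py.
    n_hidden: int
    n_head: int
    n_layer: int
    gradient_checkpointing: bool = False  # new field, defaulted so old configs still load


# A config saved before gradient_checkpointing existed still constructs cleanly:
old_config = {'n_hidden': 768, 'n_head': 12, 'n_layer': 12}
params = ParamsSketch(**old_config)
assert params.gradient_checkpointing is False


# The getattr guard in forward covers hparams objects that predate the field
# entirely, i.e. objects with no gradient_checkpointing attribute at all:
class OldParamsSketch:  # hypothetical previously pickled hparams object
    n_hidden, n_head, n_layer = 768, 12, 12


assert getattr(OldParamsSketch(), 'gradient_checkpointing', False) is False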

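The reformatted call in forward is the stock torch.utils.checkpoint.checkpoint(function, *args) API: it avoids storing the intermediate activations inside each block during the forward pass and recomputes them during backward, trading extra compute for lower memory use. A minimal, self-contained sketch of the same call pattern, with a stack of nn.Linear layers standing in for this repository's Block modules (the names below are illustrative, not from lm/model.py):

import torch
import torch.utils.checkpoint
from torch import nn

blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
h = torch.randn(2, 16, requires_grad=True)  # checkpoint needs at least one input that requires grad

gradient_checkpointing = True
for block in blocks:
    if gradient_checkpointing:
        # Forward runs block(h) without keeping its internals; backward re-runs the block.
        h = torch.utils.checkpoint.checkpoint(block, h)
    else:
        h = block(h)

h.sum().backward()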