Skip to content

Commit 5a9e52b

Browse files
committed
fixes vocab bug for gpt; refer #8
1 parent f926e7d commit 5a9e52b

File tree

7 files changed

+26
-10
lines changed

7 files changed

+26
-10
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ This repository provides the pytorch source code, and data for tabular transform
1818

1919
(X) represents the versions which code is tested on.
2020

21-
These can be installed using pip by running :
21+
These can be installed using yaml by running :
2222
```
23-
pip install -r requirements.txt
23+
conda env create -f setup.yml
2424
```
2525
---
2626

dataset/card.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def get_csv(self, fname):
239239
data = pd.read_csv(fname, nrows=self.nrows)
240240
if self.user_ids:
241241
log.info(f'Filtering data by user ids list: {self.user_ids}...')
242+
self.user_ids = map(int, self.user_ids)
242243
data = data[data['User'].isin(self.user_ids)]
244+
243245
self.nrows = data.shape[0]
244246
log.info(f"read data : {data.shape}")
245247
return data

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def main(args):
9090
tab_net = TabFormerGPT2(custom_special_tokens,
9191
vocab=vocab,
9292
field_ce=args.field_ce,
93-
flatten=args.flatten
93+
flatten=args.flatten,
9494
)
9595

9696
log.info(f"model initiated: {tab_net.model.__class__}")

models/modules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def get_model(self, field_ce, flatten):
9797
else:
9898
model = GPT2LMHeadModel(self.config)
9999
if not flatten:
100-
tab_emb_config = ddict(ncols=self.ncols, vocab_size=len(self.vocab), hidden_size=self.config.hidden_size)
100+
tab_emb_config = ddict(vocab_size=len(self.vocab), hidden_size=self.config.hidden_size)
101101
model = TabFormerBaseModel(model, TabFormerEmbeddings(tab_emb_config))
102+
102103
return model

models/tabformer_gpt2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def forward(
4646

4747
seq_len = shift_logits.size(1)
4848
total_lm_loss = 0
49-
field_names = self.vocab.get_field_keys(input_only=True, ignore_special=True)
49+
field_names = self.vocab.get_field_keys(remove_target=True, ignore_special=True)
50+
5051
for field_idx, field_name in enumerate(field_names):
5152
col_ids = list(range(field_idx, seq_len, len(field_names)))
5253
global_ids_field = self.vocab.get_field_ids(field_name)

requirements.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

setup.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
name: tabformer
2+
channels:
3+
- anaconda
4+
- pytorch
5+
- huggingface
6+
- conda-forge
7+
dependencies:
8+
- python>=3.8
9+
- pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
10+
- torchvision
11+
- pandas
12+
- scikit-learn
13+
- transformers
14+
- numpy
15+
- libgcc
16+
- pip:
17+
- transformers==3.2.0

0 commit comments

Comments
 (0)