Commit: instructions
lrei committed May 19, 2023
1 parent ef26bdd commit 9edb883
Showing 4 changed files with 35 additions and 13 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
# experiments_emolit
Experiments for the emotions in literature paper.
Train a model like in the emotions in literature paper.

Data: https://zenodo.org/record/7883954

Extract the data into the `data` dir so it looks like this: `data/emolit`.

Install the requirements from `requirements.txt` (consider using a virtualenv).

Change any parameters in the `soft_train.py` file (e.g. encoder model, batch
size, number of epochs, ...).

Run: `python soft_train.py`. This should train a model and save it to the `model` directory.
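Per the constants in the soft_train.py diff below, training expects `trn.tsv`, `val.tsv`, and `gold.tsv` under `data/emolit` after extraction. A small convenience check for that layout (not part of the repo, just a sketch):

```python
import tempfile
from pathlib import Path

# File names soft_train.py reads (TRAIN_FILE, VAL_FILE, TST_FILE).
EXPECTED = ["trn.tsv", "val.tsv", "gold.tsv"]

def missing_files(data_dir):
    """Return the expected TSV files not present under data_dir."""
    data_dir = Path(data_dir)
    return [name for name in EXPECTED if not (data_dir / name).is_file()]

# Demo against a throwaway directory standing in for data/emolit:
with tempfile.TemporaryDirectory() as tmp:
    for name in ("trn.tsv", "val.tsv"):
        (Path(tmp) / name).touch()
    print(missing_files(tmp))  # -> ['gold.tsv']
```

Run it from the repo root with `missing_files("data/emolit")`; an empty list means the data landed in the right place.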
4 changes: 4 additions & 0 deletions data/.gitignore
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
5 changes: 5 additions & 0 deletions model/.gitignore
@@ -0,0 +1,5 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

26 changes: 14 additions & 12 deletions soft_train.py
@@ -17,28 +17,30 @@
from mutils import compute_metrics_soft, write_metrics
from loss import SoftBCETrainer

logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)

# Data
TRAIN_FILE = "trn.tsv"
VAL_FILE = "val.tsv"

DATA_DIR = "./data/emolit"
TST_FILE = "./data/emolit/gold.tsv"
output_dir = "./model"
NUM_EPOCHS = 10
BS = 16
LOAD_BEST = True


SEQLEN = 64
# Params
MODEL_NAME = "roberta-large"
SEQLEN = 64
BS = 16
NUM_EPOCHS = 10
LOAD_BEST = True
os.makedirs(output_dir, exist_ok=True)

# logging
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)


# Create the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "true" # ensure parallel
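The diff imports `SoftBCETrainer` from `loss`, which presumably trains with binary cross-entropy against soft (probabilistic) labels rather than hard 0/1 targets. A minimal sketch of such a loss in plain Python (function name, clamping, and mean reduction are assumptions, not the repo's implementation):

```python
import math

def soft_bce(probs, soft_targets, eps=1e-7):
    """Binary cross-entropy where the targets are probabilities in
    [0, 1] rather than hard labels, averaged over the label dimension."""
    total = 0.0
    for p, t in zip(probs, soft_targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(probs)

# With a soft target of 0.5, neither a high nor a low prediction
# is fully rewarded:
print(soft_bce([0.9, 0.2], [1.0, 0.5]))
```

A custom `Trainer` subclass would typically apply this per example inside an overridden `compute_loss`; the repo's actual definition lives in `loss.py`.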
