Skip to content

Commit

Permalink
Add prompt_method and neg_prob to the phrasecut dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
rabinadk1 committed Dec 30, 2023
1 parent aedcf8e commit 962f68b
Show file tree
Hide file tree
Showing 4 changed files with 1,779 additions and 29 deletions.
6 changes: 6 additions & 0 deletions configs/data/phrasecut.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ train_ds:
tokenizer_pretrained_path: ${tokenizer_pretrained_path}
transforms: ${train_transforms}
return_tensors: pt
prompt_method: shuffle+
neg_prob: 0.2

val_ds:
_target_: src.data.core_datasets.phrasecutdataset.PhraseCutDataset
Expand All @@ -14,6 +16,8 @@ val_ds:
tokenizer_pretrained_path: ${tokenizer_pretrained_path}
transforms: ${val_transforms}
return_tensors: pt
prompt_method: fixed
neg_prob: 0

test_ds:
_target_: src.data.core_datasets.phrasecutdataset.PhraseCutDataset
Expand All @@ -22,6 +26,8 @@ test_ds:
tokenizer_pretrained_path: ${tokenizer_pretrained_path}
transforms: ${test_transforms}
return_tensors: pt
prompt_method: fixed
neg_prob: 0

batch_size: 16
num_workers: 4
Expand Down
26 changes: 18 additions & 8 deletions configs/experiment/phrasecut.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ model:
factor: 0.1
patience: 5
net:
freeze_image_encoder: false
freeze_text_encoder: false
freeze_image_encoder: true
freeze_text_encoder: true
add_pos_enc: false
decoder_layer_kwargs:
nhead: 8
Expand All @@ -47,7 +47,7 @@ model:
compile: false # Torch compile works only when CUDA Compat >= 7.0

data:
batch_size: 32
batch_size: 64

logger:
wandb:
Expand All @@ -61,7 +61,7 @@ logger:
# Additional parameters #
############################

model_pretrained_path: openai/clip-vit-base-patch32
model_pretrained_path: openai/clip-vit-base-patch16
tokenizer_pretrained_path: ${model_pretrained_path}
image_pretrained_path: ${model_pretrained_path}
text_pretrained_path: ${model_pretrained_path}
Expand All @@ -73,19 +73,29 @@ img_mean: [0.48145466, 0.4578275, 0.40821073]
img_std: [0.26862954, 0.26130258, 0.27577711]

# Image pre-processing configs
train_transforms:
train_transforms:
_target_: albumentations.Compose
transforms:
- _target_: albumentations.Resize
height: ${img_size}
- _target_: albumentations.SmallestMaxSize
max_size: ${img_size}
- _target_: albumentations.Rotate
limit: 10
border_mode: 1 # cv2.BORDER_REPLICATE
p: 0.2
- _target_: albumentations.RandomCrop
width: ${img_size}
height: ${img_size}
- _target_: albumentations.RandomBrightnessContrast
contrast_limit: 0.1
brightness_limit: 0.1
brightness_by_max: false
- _target_: albumentations.Normalize
mean: ${img_mean}
std: ${img_std}
- _target_: albumentations.pytorch.ToTensorV2
transpose_mask: true

eval_transforms:
eval_transforms:
_target_: albumentations.Compose
transforms:
- _target_: albumentations.Resize
Expand Down
Loading

0 comments on commit 962f68b

Please sign in to comment.