Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Style2 #54

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
update
james20141606 committed Nov 3, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit c95c4e28779ca884b8abcc91337b13c56efbc826
Binary file added .DS_Store
Binary file not shown.
4 changes: 2 additions & 2 deletions ECoGDataSet.py
Original file line number Diff line number Diff line change
@@ -141,7 +141,7 @@ def __init__(self, ReqSubjDict, mode = 'train', train_param = None,BCTS=None,wor
self.BlockRegion = []
[self.BlockRegion.extend(self.cortex[area]) for area in train_param["BlockRegion"]]
self.wavebased = cfg.MODEL.WAVE_BASED
self.ReshapeAsGrid = False if 'Transformer' in cfg.MODEL.MAPPING_FROM_ECOG else True
self.ReshapeAsGrid = False if ('lstm') or ('Transformer') in cfg.MODEL.MAPPING_FROM_ECOG else True
self.Prod,self.UseGridOnly,self.SeqLen = train_param['Prod'],\
train_param['UseGridOnly'],\
train_param['SeqLen'],
@@ -339,7 +339,7 @@ def __init__(self, ReqSubjDict, mode = 'train', train_param = None,BCTS=None,wor
statics_ecog = baseline.mean(axis=0,keepdims=True)+1E-10, np.sqrt(baseline.var(axis=0, keepdims=True))+1E-10

ecog = (ecog - statics_ecog[0])/statics_ecog[1]
ecog = np.minimum(ecog,5)
ecog = np.minimum(ecog,10)#5)
ecog_len_+= [ecog.shape[0]]
ecog_+=[ecog]

21 changes: 16 additions & 5 deletions configs/ecog_style2.yaml
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ MODEL:
#####TAKE OFF CHECKLIST!!!########
N_FORMANTS: 6
N_FORMANTS_NOISE: 1
N_FORMANTS_ECOG: 2
N_FORMANTS_ECOG: 6
WAVE_BASED : True
DO_MEL_GUIDE : False
BGNOISE_FROMDATA: True
@@ -44,7 +44,15 @@ MODEL:
TRUNCATIOM_CUTOFF: 5
CHANNELS: 1
UNIQ_WORDS: 50
MAPPING_FROM_ECOG: "ECoGMappingBottleneck"
#MAPPING_FROM_ECOG: "ECoGMappingBottleneck" #ECoGMappingBottlenecklstm1, ECoGMappingBottlenecklstm2
#MAPPING_FROM_ECOG: "ECoGMappingBottlenecklstm1"
#MAPPING_FROM_ECOG: "ECoGMappingBottlenecklstm"
MAPPING_FROM_ECOG: "ECoGMappingBottlenecklstm_pure"
ONEDCONFIRST: True
RNN_TYPE: 'LSTM'
RNN_LAYERS: 4
RNN_COMPUTE_DB_LOUDNESS: True
BIDIRECTION: True
# MAPPING_FROM_ECOG: "ECoGMappingTransformer"
ECOG: False #will be overloaded if FINETUNE
SUPLOSS_ON_ECOGF: False # will be overloaded to FIX_GEN if FINETUNE,spec supervise loss only apply to ecog encoder
@@ -78,7 +86,11 @@ MODEL:
N_HEADS : 4
NON_LOCAL: True
# ATTENTION: []
OUTPUT_DIR: training_artifacts/debug_
#OUTPUT_DIR: output/ecog_10241800_lstm1 #training_artifacts/debug
#OUTPUT_DIR: output/ecog_10241800_lstm2
#OUTPUT_DIR: output/ecog_11011800_conv #after change loudness encoder
#OUTPUT_DIR: output/ecog_11011800_lstm1 #after change loudness encoder
OUTPUT_DIR: output/ecog_11021800_lstm1 #after change loudness encoder
# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampamploss
# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampsynth_masknormed
# OUTPUT_DIR: training_artifacts/debug_f1f2linearmel
@@ -88,9 +100,8 @@ OUTPUT_DIR: training_artifacts/debug_
# OUTPUT_DIR: training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_noprogressive_HBw_ppl_ppld_localreg_ecogf_w_spec_sup
# OUTPUT_DIR: training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_ppl_ppld
# OUTPUT_DIR: training_artifacts/ecog_residual_cycle_attention3264wStyleIN_specchan64_more_attentfeatures_heads4

FINETUNE:
FINETUNE: False
FINETUNE: True
FIX_GEN: True
ENCODER_GUIDE: True
SPECSUP: True
121 changes: 121 additions & 0 deletions configs/ecog_style2_a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Config for training ALAE on FFHQ at resolution 1024x1024

NAME: ecog
DATASET:
PART_COUNT: 16
SIZE: 60000
FFHQ_SOURCE: /data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords
PATH: /data/datasets/ffhq-dataset_new/tfrecords/ffhq/splitted/ffhq-r%02d.tfrecords.%03d

FLIP_IMAGES: False

PART_COUNT_TEST: 4
PATH_TEST: /data/datasets/ffhq-dataset_new/tfrecords/ffhq-test/splitted/ffhq-r%02d.tfrecords.%03d

SAMPLES_PATH: ''
STYLE_MIX_PATH: style_mixing/test_images/set_ecog
SPEC_CHANS: 64
TEMPORAL_SAMPLES: 128
BCTS: True
MAX_RESOLUTION_LEVEL: 7
SUBJECT: ['NY742']
MODEL:
#####TAKE OFF CHECKLIST!!!########
N_FORMANTS: 6
N_FORMANTS_NOISE: 1
N_FORMANTS_ECOG: 2
WAVE_BASED : True
DO_MEL_GUIDE : False
BGNOISE_FROMDATA: True
N_FFT : 256
NOISE_DB : -50 #-50
MAX_DB : 22.5 #probablity 28 is better
NOISE_DB_AMP : -25
MAX_DB_AMP : 14
POWER_SYNTH: True

LESS_TEMPORAL_FEATURE: True
LATENT_SPACE_SIZE: 128
LAYER_COUNT: 6
MAX_CHANNEL_COUNT: 512
START_CHANNEL_COUNT: 16
DLATENT_AVG_BETA: 0.995
MAPPING_LAYERS: 8
TRUNCATIOM_CUTOFF: 5
CHANNELS: 1
UNIQ_WORDS: 50
MAPPING_FROM_ECOG: "ECoGMappingBottleneck"
# MAPPING_FROM_ECOG: "ECoGMappingTransformer"
ECOG: False #will be overloaded if FINETUNE
SUPLOSS_ON_ECOGF: False # will be overloaded to FIX_GEN if FINETUNE,spec supervise loss only apply to ecog encoder
W_SUP: False
GAN: True
GENERATOR: "GeneratorFormant"
ENCODER: "EncoderFormant"
AVERAGE_W: True
TEMPORAL_W: True
GLOBAL_W: True
TEMPORAL_GLOBAL_CAT: True
RESIDUAL: True
W_CLASSIFIER: False
CYCLE: False
ATTENTIONAL_STYLE: True
#T 4 8 16 32 64 128
ATTENTION: [False, False, False, False, False, False]
HEADS: 1
APPLY_PPL: False
APPLY_PPL_D: False
PPL_WEIGHT: 100
PPL_GLOBAL_WEIGHT: 0
PPLD_WEIGHT: 1
PPLD_GLOBAL_WEIGHT: 0
COMMON_Z: True
TRANSFORMER:
HIDDEN_DIM : 256
DIM_FEEDFORWARD : 256
ENCODER_ONLY : False
ATTENTIONAL_MASK : False
N_HEADS : 4
NON_LOCAL: True
# ATTENTION: []
OUTPUT_DIR: output/audio_11021700 #training_artifacts/debug_
# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampamploss
# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampsynth_masknormed
# OUTPUT_DIR: training_artifacts/debug_f1f2linearmel
# OUTPUT_DIR: training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicformantsemph
# OUTPUT_DIR: training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicnoiseformantsemphmore
# OUTPUT_DIR: training_artifacts/formantsythv2_wavebased_NY742_constraintonFB_Bconstrainrefined_absfreq_4formants_1noiseformants_bgnoise_noisemapping_freqconv_duomask
# OUTPUT_DIR: training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_noprogressive_HBw_ppl_ppld_localreg_ecogf_w_spec_sup
# OUTPUT_DIR: training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_ppl_ppld
# OUTPUT_DIR: training_artifacts/ecog_residual_cycle_attention3264wStyleIN_specchan64_more_attentfeatures_heads4

FINETUNE:
FINETUNE: False
FIX_GEN: True
ENCODER_GUIDE: True
SPECSUP: True
#####################################

TRAIN:
PROGRESSIVE: False
W_WEIGHT: 1
CYCLE_WEIGHT: 1
BASE_LEARNING_RATE: 0.002
EPOCHS_PER_LOD: 16
LEARNING_DECAY_RATE: 0.1
LEARNING_DECAY_STEPS: [96]
TRAIN_EPOCHS: 60
# 4 8 16 32 64 128 256
LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32] # If GPU memory ~16GB reduce last number from 32 to 24
LOD_2_BATCH_4GPU: [64, 64, 64, 64, 32, 16]
LOD_2_BATCH_2GPU: [64, 64, 64, 64, 32, 16]
# LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16]
# LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 32]
# LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16]
# LOD_2_BATCH_1GPU: [128, 128, 128, 128, 64, 32]
# LOD_2_BATCH_1GPU: [512, 256, 256, 128, 64, 16]
LOD_2_BATCH_1GPU: [64, 64, 64, 64, 32, 16]
BATCH_SIZE : 32
# BATCH_SIZE : 2
LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003]
# LEARNING_RATES: [0.0015, 0.0015, 0.0005, 0.0003, 0.0003, 0.0002]
5 changes: 5 additions & 0 deletions defaults.py
Original file line number Diff line number Diff line change
@@ -72,6 +72,11 @@
_C.MODEL.MAPPING_TO_LATENT = "MappingToLatent"
_C.MODEL.MAPPING_FROM_LATENT = "MappingFromLatent"
_C.MODEL.MAPPING_FROM_ECOG = "ECoGMappingDefault"
_C.MODEL.ONEDCONFIRST = True
_C.MODEL.RNN_TYPE = 'LSTM'
_C.MODEL.RNN_LAYERS = 4
_C.MODEL.RNN_COMPUTE_DB_LOUDNESS = True
_C.MODEL.BIDIRECTION = True
_C.MODEL.Z_REGRESSION = False
_C.MODEL.AVERAGE_W = False
_C.MODEL.TEMPORAL_W = False
22 changes: 15 additions & 7 deletions formant_systh.py
Original file line number Diff line number Diff line change
@@ -198,7 +198,9 @@ def subfigure_plot(ax,spec,components,n_mels,which_formant='hamon',formant_line=
if ecog is not None:
comp_ecog = ecog['amplitude_formants_hamon'] if title=='amplitude_hamon' else ecog['amplitude_formants_noise']
plt.sca(ax)
plt.yticks(range(0,200,20),(np.arange(0,1,20)).astype(str))
ax.set_yticks(np.arange(0,200,20))
ax.set_yticklabels(np.arange(0,200,20)/200)
#plt.yticks(range(0,200,20),(np.arange(0,1,20)).astype(str))
ax.imshow(np.clip(1-spec.detach().cpu().numpy().squeeze().T,0,1),vmin=0.0,vmax=1.0)
for i in range(comp.shape[1]):
ax.plot(200*comp[:,i].squeeze().detach().cpu().numpy().T,linewidth=2,color=clrs[i])
@@ -372,10 +374,7 @@ def save_sample(sample,ecog,mask_prior,mni,encoder,decoder,ecog_encoder,epoch,la
fig.savefig(f, bbox_inches='tight',dpi=80)
plt.close(fig)

scipy.io.wavfile.write(f2+'denoisewave.wav',16000,torch.cat(rec_denoise_wave_all.unbind(),1)[0].detach().cpu().numpy())
if ecog_encoder is not None:
scipy.io.wavfile.write(f2+'denoiseecogwave.wav',16000,torch.cat(rec_denoise_ecog_wave_all.unbind(),1)[0].detach().cpu().numpy())


if linear:
rec_all = amplitude(torch.cat((2*rec_all[:,0]-1).unbind(),0).transpose(-2,-1).detach().cpu().numpy(),-50,22.5,trim_noise=False)
rec_wave = spsi(rec_all,(n_fft-1)*2,128)
@@ -393,9 +392,18 @@ def save_sample(sample,ecog,mask_prior,mni,encoder,decoder,ecog_encoder,epoch,la

save_image(resultsample, f2, nrow=resultsample.shape[0]//(2 if ecog_encoder is None else 3))
# import pdb;pdb.set_trace()

if ecog_encoder is not None:
scipy.io.wavfile.write(f2+'denoisewave.wav',16000,torch.cat(rec_denoise_wave_all.unbind(),1)[0].detach().cpu().numpy())
scipy.io.wavfile.write(f2+'denoiseecogwave.wav',16000,torch.cat(rec_denoise_ecog_wave_all.unbind(),1)[0].detach().cpu().numpy())
if mode =='test':
return torch.cat(rec_denoise_ecog_wave_all.unbind(),1)[0].detach().cpu().numpy()
else:
scipy.io.wavfile.write(f2+'denoisewave.wav',16000,torch.cat(rec_denoise_wave_all.unbind(),1)[0].detach().cpu().numpy())
if mode =='test':
return torch.cat(rec_denoise_wave_all.unbind(),1)[0].detach().cpu().numpy()



return

def main():
OUTPUT_DIR = 'training_artifacts/formantsysth_voicingandunvoicing_loudness_NY742'
8 changes: 5 additions & 3 deletions launcher.py
Original file line number Diff line number Diff line change
@@ -111,7 +111,8 @@ def _run(rank, world_size, fn, defaults, write_log, no_cuda, args):
cleanup()


def run(fn, defaults, description='', default_config='configs/experiment.yaml', world_size=1, write_log=False, no_cuda=False):
def run(fn, defaults, description='', default_config='configs/experiment.yaml', world_size=1, write_log=False, no_cuda=False,args=None):
'''
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-c", "--config-file",
@@ -126,13 +127,14 @@ def run(fn, defaults, description='', default_config='configs/experiment.yaml',
default=None,
nargs=argparse.REMAINDER,
)

args = parser.parse_args()
'''
import multiprocessing
cpu_count = multiprocessing.cpu_count()
os.environ["OMP_NUM_THREADS"] = str(max(1, int(cpu_count / world_size)))
del multiprocessing

args = parser.parse_args()


if world_size > 1:
mp.spawn(_run,
320 changes: 274 additions & 46 deletions model_formant.py

Large diffs are not rendered by default.

510 changes: 482 additions & 28 deletions net_formant.py

Large diffs are not rendered by default.

81 changes: 68 additions & 13 deletions train_formant.py → train_formant_a.py
Original file line number Diff line number Diff line change
@@ -38,8 +38,20 @@
from ECoGDataSet import concate_batch
from formant_systh import save_sample

from tensorboardX import SummaryWriter


import argparse

parser = argparse.ArgumentParser(description='Process some integers.')

parser.add_argument('-m', '--modeldir', type=str,default=' ',
help='')
argus = parser.parse_args()


def train(cfg, logger, local_rank, world_size, distributed):
writer = SummaryWriter(cfg.OUTPUT_DIR)
torch.cuda.set_device(local_rank)
model = Model(
generator=cfg.MODEL.GENERATOR,
@@ -191,11 +203,6 @@ def train(cfg, logger, local_rank, world_size, distributed):
save=local_rank == 0)

extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=True,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicformantsemph/model_epoch23.pth')
# extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=False,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/debug_f1f2linearmel/model_epoch27.pth')
# extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=False,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/debug_fitf1f2freqonly/model_epoch28.pth')
# extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=False,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/debug_fitf1f2freqonly/model_epoch6.pth')
# extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=True,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/loudnesscomp_han5/model_epoch50.pth')
# extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=False,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/test_9/model_epoch30.pth')
logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

arguments.update(extra_checkpoint_data)
@@ -262,6 +269,7 @@ def train(cfg, logger, local_rank, world_size, distributed):
# model.eval()
# Lrec = model(sample_spec_test, x_denoise = sample_spec_denoise_test,x_mel = sample_spec_mel_test,ecog=ecog_test if cfg.MODEL.ECOG else None, mask_prior=mask_prior_test if cfg.MODEL.ECOG else None, on_stage = on_stage_test,on_stage_wider = on_stage_wider_test, ae = not cfg.MODEL.ECOG, tracker = tracker_test, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate_test,debug = False,x_amp=sample_spec_amp_test,hamonic_bias = False)
# save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=0,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=True)
n_iter = 0
for epoch in range(cfg.TRAIN.TRAIN_EPOCHS):
model.train()

@@ -271,6 +279,7 @@ def train(cfg, logger, local_rank, world_size, distributed):
epoch_start_time = time.time()
i = 0
for sample_dict_train in tqdm(iter(dataset.iterator)):
n_iter +=1
# import pdb; pdb.set_trace()
# sample_dict_train = concate_batch(sample_dict_train)
i += 1
@@ -310,12 +319,32 @@ def train(cfg, logger, local_rank, world_size, distributed):

if (cfg.MODEL.ECOG):
optimizer.zero_grad()
Lrec = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=ecog, mask_prior=mask_prior, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = False, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,duomask=duomask,mni=mni_coordinate,x_amp=x_orig_amp)
Lrec,tracker = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=ecog, mask_prior=mask_prior, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = False, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,duomask=duomask,mni=mni_coordinate,x_amp=x_orig_amp)
#print ('tracker',tracker,tracker.tracks)
#for key in ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']:
# print ('tracker, ',key, tracker.tracks[key].mean(dim=0))
ecog_key_list = ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']
if n_iter %10==0:
#print ('write tensorboard')
writer.add_scalars('data/loss_group', {key: tracker.tracks[key].mean(dim=0) for key in ecog_key_list}, n_iter)
for key in ecog_key_list:
writer.add_scalar('data/'+key, tracker.tracks[key].mean(dim=0), n_iter)
#pass #log in tensorboard later!!
#tracker Lae_a: 0.0536877, Lae_a_l2: 0.0538647, Lae_db: 0.1655398, Lae_db_l2: 0.1655714, Lloudness: 1.0384552, Lae_denoise: 0.0485827, Lamp: 0.0000148, Lae: 2.1787138, Lexp: -1.9956266, Lf0: 0.0000000, Ldiff: 0.0568467
(Lrec).backward()
optimizer.step()
else:
optimizer.zero_grad()
Lrec = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=None, mask_prior=None, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = True, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate,debug = False,x_amp=x_orig_amp,hamonic_bias = False)#epoch<2)
Lrec,tracker = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=None, mask_prior=None, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = True, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate,debug = False,x_amp=x_orig_amp,hamonic_bias = False)#epoch<2)
#print ('tracker',tracker,tracker.tracks)
#for key in ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']:
#print ('tracker, ',key, tracker.tracks[key].mean(dim=0))
ecog_key_list = ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']
if n_iter %10==0:
#print ('write tensorboard')
writer.add_scalars('data/loss_group', {key: tracker.tracks[key].mean(dim=0) for key in ecog_key_list}, n_iter)
for key in ecog_key_list:
writer.add_scalar('data/'+key, tracker.tracks[key].mean(dim=0), n_iter)
(Lrec).backward()
optimizer.step()

@@ -324,18 +353,44 @@ def train(cfg, logger, local_rank, world_size, distributed):

epoch_end_time = time.time()
per_epoch_ptime = epoch_end_time - epoch_start_time

#print ('test save sample')
#save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=sample_spec_amp_test)
#save_sample(x,ecog,mask_prior,mni_coordinate,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=x_orig_denoise,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=labels,mode='train',path=cfg.OUTPUT_DIR,tracker = tracker,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=x_orig_amp)
#print ('finish')

if local_rank == 0:
print(2**(torch.tanh(model.encoder.formant_bandwitdh_slop)))
checkpointer.save("model_epoch%d" % epoch)
model.eval()
Lrec = model(sample_spec_test, x_denoise = sample_spec_denoise_test,x_mel = sample_spec_mel_test,ecog=ecog_test if cfg.MODEL.ECOG else None, mask_prior=mask_prior_test if cfg.MODEL.ECOG else None, on_stage = on_stage_test,on_stage_wider = on_stage_wider_test, ae = not cfg.MODEL.ECOG, tracker = tracker_test, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate_test,debug = False,x_amp=sample_spec_amp_test,hamonic_bias = False)
save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=sample_spec_amp_test)
save_sample(x,ecog,mask_prior,mni_coordinate,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=x_orig_denoise,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=labels,mode='train',path=cfg.OUTPUT_DIR,tracker = tracker,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=x_orig_amp)

if epoch%1==0:
#first mode is test
reconaudio = save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=sample_spec_amp_test)
save_sample(x,ecog,mask_prior,mni_coordinate,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=x_orig_denoise,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=labels,mode='train',path=cfg.OUTPUT_DIR,tracker = tracker,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=x_orig_amp)
writer.add_audio('reconaudio', reconaudio, n_iter, sample_rate=16000)
writer.export_scalars_to_json(cfg.OUTPUT_DIR+"/all_scalars.json")
writer.close()

if __name__ == "__main__":
gpu_count = torch.cuda.device_count()
run(train, get_cfg_defaults(), description='StyleGAN', default_config='configs/ecog_style2.yaml',
world_size=gpu_count)
cfg = get_cfg_defaults()
parser = argparse.ArgumentParser(description='formant')
parser.add_argument(
"-c", "--config-file",
default='configs/ecog_style2_a.yaml',
metavar="FILE",
help="path to config file",
type=str,
)
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
args = parser.parse_args()

#if args.modeldir !='':
# cfg.OUTPUT_DIR = args.modeldir
run(train, cfg, description='StyleGAN', default_config='configs/ecog_style2_a.yaml',
world_size=gpu_count,args=args)
427 changes: 427 additions & 0 deletions train_formant_e.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion train_param.json
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
"SelectRegion":["AUDITORY","BROCA","MOTO","SENSORY"],
"BlockRegion":[],
"UseGridOnly":true,
"ReshapeAsGrid":true,
"ReshapeAsGrid":false,
"SeqLen":128,
"DOWN_TF_FS": 125,
"DOWN_ECOG_FS": 125,
1 change: 1 addition & 0 deletions transformer_models/util/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
88 changes: 88 additions & 0 deletions transformer_models/util/box_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter

iou = inter / union
return iou, union


def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)

lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]

return iou - (area - union) / area


def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)

h, w = masks.shape[-2:]

y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)

x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

return torch.stack([x_min, y_min, x_max, y_max], 1)
416 changes: 416 additions & 0 deletions transformer_models/util/misc.py

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions transformer_models/util/plot_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt


def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0):
dfs = [pd.read_json(Path(p) / 'log.txt', lines=True) for p in logs]

fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))

for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
for j, field in enumerate(fields):
if field == 'mAP':
coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean()
axs[j].plot(coco_eval, c=color)
else:
df.interpolate().ewm(com=ewm_col).mean().plot(
y=[f'train_{field}', f'test_{field}'],
ax=axs[j],
color=[color] * 2,
style=['-', '--']
)
for ax, field in zip(axs, fields):
ax.legend([Path(p).name for p in logs])
ax.set_title(field)


def plot_precision_recall(files, naming_scheme='iter'):
if naming_scheme == 'exp_id':
# name becomes exp_id
names = [f.parts[-3] for f in files]
elif naming_scheme == 'iter':
names = [f.stem for f in files]
else:
raise ValueError(f'not supported {naming_scheme}')
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
data = torch.load(f)
# precision is n_iou, n_points, n_cat, n_area, max_det
precision = data['precision']
recall = data['params'].recThrs
scores = data['scores']
# take precision for all classes, all areas and 100 detections
precision = precision[0, :, :, 0, -1].mean(1)
scores = scores[0, :, :, 0, -1].mean(1)
prec = precision.mean()
rec = data['recall'][0, :, 0, -1].mean()
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
f'score={scores.mean():0.3f}, ' +
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
)
axs[0].plot(recall, precision, c=color)
axs[1].plot(recall, scores, c=color)

axs[0].set_title('Precision / Recall')
axs[0].legend(names)
axs[1].set_title('Scores / Recall')
axs[1].legend(names)
return fig, axs