diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..accf93c5 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index d0a10be2..42526be5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,9 @@ __pycache__/ .idea/ *.pdf -# *.png +*.png *.eps *.txt +*.jpg results*/ + diff --git a/AllSubjectInfo.json b/AllSubjectInfo.json new file mode 100644 index 00000000..a90830d4 --- /dev/null +++ b/AllSubjectInfo.json @@ -0,0 +1,43 @@ +{ + "Shared":{ + "RootPath":"/scratch/akg404-share/ECoGData_Mar_11_20/FunctionalMapping/", + "ORG_WAVE_FS": 24414.1, + "DOWN_WAVE_FS": 16000, + "ORG_ECOG_FS": 3051.7625, + "ORG_ECOG_FS_NY": 512, + "ORG_TF_FS": 125, + "AUDITORY" : ["cSTG","mSTG"], + "BROCA" : ["parstriangularis","parsopercularis"], + "MOTO" : ["precentral"], + "SENSORY" : ["postcentral"] + }, + "Subj":{ + "NY717":{ + "Density":"HB", + "Task":["AudN","SenComp","VisRead","PicN","AudRep"] + }, + "NY742":{ + "Density":"HB", + "Task":["AudN","SenComp","VisRead","PicN","AudRep"] + }, + "NY749":{ + "Density":"HB", + "Task":["AudN","SenComp","VisRead","PicN","AudRep"] + }, + "HD06":{ + "Density":"HD", + "Task":["AudName","AudRep"], + "EventRange":-100 + }, + "HD01":{ + "Density":"HD", + "Task":["AudName","AudRep"], + "BadSamples":[[1,2,3],[1,3]] + } + }, + "BadSamples":{ + "HD01":{ + "AudRep":[1,2,3] + } + } +} \ No newline at end of file diff --git a/ECoGDataSet.py b/ECoGDataSet.py new file mode 100644 index 00000000..508b9c10 --- /dev/null +++ b/ECoGDataSet.py @@ -0,0 +1,976 @@ +import json +import pdb +import torch +import os +import numpy as np +import scipy.io +from scipy import signal +import h5py +import random +import pandas +from torch.utils.data import Dataset +from defaults import get_cfg_defaults +from net_formant import wave2spec +cfg = get_cfg_defaults() +cfg.merge_from_file('configs/ecog_style2.yaml') +BCTS = cfg.DATASET.BCTS +if not cfg.MODEL.POWER_SYNTH: + cfg.MODEL.NOISE_DB = cfg.MODEL.NOISE_DB_AMP + cfg.MODEL.MAX_DB = cfg.MODEL.MAX_DB_AMP + +class ECoGDataset(Dataset): + """docstring for ECoGDataset""" + def zscore(self,ecog,badelec,axis=None): + statics_ecog = np.delete(ecog,badelec,axis=1).mean(axis=axis, keepdims=True)+1e-10,np.delete(ecog,badelec,axis=1).std(axis=axis, keepdims=True)+1e-10 + # statics_ecog = ecog.mean(axis=axis, keepdims=True)+1e-10,ecog.std(axis=axis, keepdims=True)+1e-10 + ecog = (ecog-statics_ecog[0])/statics_ecog[1] + return ecog, statics_ecog + + def rearrange(self,data,crop=None,mode = 'ecog'): + rows = [0,1,2,3,4,5,6,8,9,10,11] + starts = [1,0,1,0,1,0,1,7,6,7,7] + ends = [6,6,6,9,12,14,12,14,14,14,8] + strides = [2,1,2,1,2,1,2,2,1,2,1] + electrodes = [64,67,73,76,85,91,105,111,115,123,127,128] + if mode == 'ecog': + data_new = np.zeros((data.shape[0],15,15)) + data_new[:,::2,::2] = np.reshape(data[:,:64],[-1,8,8]) + for i in range(len(rows)): + data_new[:,rows[i],starts[i]:ends[i]:strides[i]] = data[:,electrodes[i]:electrodes[i+1]] + if crop is None: + return np.reshape(data_new,[data.shape[0],-1]) + else: + return np.reshape(data_new[:,crop[0]:crop[0]+crop[2],crop[1]:crop[1]+crop[3]],[data.shape[0],-1]) # TxN + + elif mode == 'coord': + data_new = np.zeros((15,15,data.shape[-1])) + data_new[::2,::2] = np.reshape(data[:64],[8,8,-1]) + for i in range(len(rows)): + data_new[rows[i],starts[i]:ends[i]:strides[i]] = data[electrodes[i]:electrodes[i+1]] + if crop is None: + return np.reshape(data_new,[-1,data.shape[-1]]) # Nx3 + else: + return 
np.reshape(data_new[crop[0]:crop[0]+crop[2],crop[1]:crop[1]+crop[3]],[-1,data.shape[-1]]) # Nx3 + + elif mode == 'region': + region_new = np.chararray((15,15),itemsize=100) + region_new[:] = 'nan' + region_new[::2,::2] = np.reshape(data[:64],[8,8]) + for i in range(len(rows)): + region_new[rows[i],starts[i]:ends[i]:strides[i]] = data[electrodes[i]:electrodes[i+1]] + if crop is None: + return np.reshape(region_new,[-1]) + else: + return np.reshape(region_new[crop[0]:crop[0]+crop[2],crop[1]:crop[1]+crop[3]],[-1]) + + elif mode == 'mask': + data_new = np.zeros((15,15)) + data_new[::2,::2] = np.reshape(data[:64],[8,8]) + for i in range(len(rows)): + data_new[rows[i],starts[i]:ends[i]:strides[i]] = data[electrodes[i]:electrodes[i+1]] + if crop is None: + return np.reshape(data_new,[-1]) + else: + return np.reshape(data_new[crop[0]:crop[0]+crop[2],crop[1]:crop[1]+crop[3]],[-1]) + + def select_block(self,ecog,regions,mask,mni_coord,select,block): + if not select and not block: + return ecog,regions,mask,mni_coord + if self.ReshapeAsGrid: + if select: + ecog_ = np.zeros(ecog.shape) + mask_ = np.zeros(mask.shape) + mni_coord_ = np.zeros(mni_coord.shape) + for region in select: + region_ind = [region.encode() == regions[i] for i in range(regions.shape[0])] + ecog_[:,region_ind] = ecog[:,region_ind] + mask_[region_ind] = mask[region_ind] + mni_coord_[region_ind] = mni_coord[region_ind] + return ecog_,regions,mask_,mni_coord_ + if block: + for region in block: + region_ind = [region.encode() == regions[i] for i in range(regions.shape[0])] + ecog[:,region_ind] = 0 + mask[region_ind] = 0 + mni_coord[region_ind]=0 + return ecog,regions,mask,mni_coord + else: + # region_ind = np.ones(regions.shape[0],dtype=bool) + region_ind = np.array([]) + if select: + # region_ind = np.zeros(regions.shape[0],dtype=bool) + for region in select: + region_ind = np.concatenate([region_ind, np.where(np.array([region in regions[i] for i in range(regions.shape[0])]))[0]]) + if block: + # region_ind = np.zeros(regions.shape[0],dtype=bool) + for region in block: + # region_ind = np.logical_or(region_ind, np.array([region in regions[i] for i in range(regions.shape[0])])) + region_ind = np.concatenate([region_ind, np.where(np.array([region in regions[i] for i in range(regions.shape[0])]))[0]]) + # region_ind = np.logical_not(region_ind) + region_ind = np.delete(np.arange(regions.shape[0]),region_ind) + region_ind = region_ind.astype(np.int64) + return ecog[:,region_ind],regions[region_ind],mask[region_ind],mni_coord[region_ind] + def __init__(self, ReqSubjDict, mode = 'train', train_param = None,BCTS=None,world_size=1): + """ ReqSubjDict can be a list of multiple subjects""" + super(ECoGDataset, self).__init__() + self.world_size = world_size + self.current_lod=2 + self.ReqSubjDict = ReqSubjDict + self.mode = mode + self.BCTS = BCTS + self.SpecBands = cfg.DATASET.SPEC_CHANS + with open('AllSubjectInfo.json','r') as rfile: + allsubj_param = json.load(rfile) + if (train_param == None): + with open('train_param.json','r') as rfile: + train_param = json.load(rfile) + + self.rootpath = allsubj_param['Shared']['RootPath'] + self.ORG_WAVE_FS = allsubj_param['Shared']['ORG_WAVE_FS'] + self.ORG_ECOG_FS = allsubj_param['Shared']['ORG_ECOG_FS'] + self.DOWN_WAVE_FS = allsubj_param['Shared']['DOWN_WAVE_FS'] + self.ORG_ECOG_FS_NY = allsubj_param['Shared']['ORG_ECOG_FS_NY'] + self.ORG_TF_FS = allsubj_param['Shared']['ORG_TF_FS'] + self.cortex = {} + self.cortex.update({"AUDITORY" : allsubj_param['Shared']['AUDITORY']}) + 
self.cortex.update({"BROCA" : allsubj_param['Shared']['BROCA']}) + self.cortex.update({"MOTO" : allsubj_param['Shared']['MOTO']}) + self.cortex.update({"SENSORY" : allsubj_param['Shared']['SENSORY']}) + self.SelectRegion = [] + [self.SelectRegion.extend(self.cortex[area]) for area in train_param["SelectRegion"]] + self.BlockRegion = [] + [self.BlockRegion.extend(self.cortex[area]) for area in train_param["BlockRegion"]] + self.wavebased = cfg.MODEL.WAVE_BASED + self.ReshapeAsGrid = False if ('lstm') or ('Transformer') in cfg.MODEL.MAPPING_FROM_ECOG else True + self.Prod,self.UseGridOnly,self.SeqLen = train_param['Prod'],\ + train_param['UseGridOnly'],\ + train_param['SeqLen'], + self.ahead_onset_test = train_param['Test']['ahead_onset'] + self.ahead_onset_train = train_param['Train']['ahead_onset'] + self.DOWN_TF_FS = train_param['DOWN_TF_FS'] + self.DOWN_ECOG_FS = train_param['DOWN_ECOG_FS'] + self.TestNum_cum=np.array([],dtype=np.int32) + self.Wipenoise = False + + datapath = [] + analysispath = [] + ecog_alldataset = [] + spkr_alldataset = [] + spkr_re_alldataset = [] + spkr_static_alldataset = [] + spkr_re_static_alldataset = [] + start_ind_alldataset = [] + start_ind_valid_alldataset = [] + start_ind_wave_alldataset = [] + start_ind_wave_valid_alldataset = [] + end_ind_alldataset = [] + end_ind_valid_alldataset = [] + end_ind_wave_alldataset = [] + end_ind_wave_valid_alldataset = [] + start_ind_re_alldataset = [] + start_ind_re_valid_alldataset = [] + start_ind_re_wave_alldataset = [] + start_ind_re_wave_valid_alldataset = [] + end_ind_re_alldataset = [] + end_ind_re_valid_alldataset = [] + end_ind_re_wave_alldataset = [] + end_ind_re_wave_valid_alldataset = [] + word_alldataset = [] + label_alldataset = [] + wave_alldataset = [] + wave_re_alldataset = [] + wave_re_spec_alldataset = [] + wave_re_spec_amp_alldataset = [] + wave_re_denoise_alldataset = [] + wave_re_spec_denoise_alldataset = [] + wave_spec_alldataset = [] + noisesample_re_alldataset = [] + noisesample_alldataset = [] + bad_samples_alldataset = [] + baseline_alldataset = [] + mni_coordinate_alldateset = [] + T1_coordinate_alldateset = [] + regions_alldataset =[] + mask_prior_alldataset = [] + dataset_names = [] + ecog_len = [] + unique_labels = [] + # self.ORG_WAVE_FS,self.DOWN_ECOG_FS,self.DOWN_WAVE_FS = allsubj_param['Shared']['ORG_WAVE_FS'],\ + # allsubj_param['Shared']['DOWN_ECOG_FS'],\ + # allsubj_param['Shared']['DOWN_WAVE_FS'],\ + + # spkrdata = h5py.File(DATA_DIR[0][0]+'TF32_16k.mat','r') + # spkr = np.asarray(spkrdata['TFlog']) + # samples_for_statics_ = spkr[statics_samples_spkr[0][0*2]:statics_samples_spkr[0][0*2+1]] + flag_zscore = False + for subj in self.ReqSubjDict: + subj_param = allsubj_param['Subj'][subj] + Density = subj_param['Density'] + Crop = train_param["Subj"][subj]['Crop'] + datapath = os.path.join(self.rootpath,subj,'data') + analysispath = os.path.join(self.rootpath,subj,'analysis') + ecog_ = [] + ecog_len_=[0] + start_ind_train_=[] + end_ind_train_ = [] + end_ind_valid_train_ = [] + start_ind_valid_train_=[] + start_ind_wave_down_train_ =[] + end_ind_wave_down_train_ =[] + start_ind_wave_down_valid_train_ =[] + end_ind_wave_down_valid_train_ =[] + start_ind_re_train_=[] + end_ind_re_train_ = [] + end_ind_re_valid_train_ = [] + start_ind_re_valid_train_=[] + start_ind_re_wave_down_train_ =[] + end_ind_re_wave_down_train_ =[] + start_ind_re_wave_down_valid_train_ =[] + end_ind_re_wave_down_valid_train_ =[] + + start_ind_test_=[] + end_ind_ = [] + end_ind_test_=[] = [] + end_ind_valid_test_ = [] 
+ start_ind_valid_test_=[] + start_ind_wave_down_test_ =[] + end_ind_wave_down_test_ =[] + start_ind_wave_down_valid_test_ =[] + end_ind_wave_down_valid_test_ =[] + start_ind_re_test_=[] + end_ind_re_test_ = [] + end_ind_re_valid_test_ = [] + start_ind_re_valid_test_=[] + start_ind_re_wave_down_test_ =[] + end_ind_re_wave_down_test_ =[] + start_ind_re_wave_down_valid_test_ =[] + end_ind_re_wave_down_valid_test_ =[] + spkr_=[] + wave_=[] + wave_spec_=[] + spkr_re_=[] + wave_re_=[] + noisesample_re_=[] + noisesample_=[] + wave_re_spec_=[] + wave_re_spec_amp_=[] + wave_re_denoise_=[] + wave_re_spec_denoise_=[] + word_train=[] + labels_train=[] + word_test=[] + labels_test=[] + bad_samples_=np.array([]) + self.TestNum_cum = np.append(self.TestNum_cum, np.array(train_param["Subj"][subj]['TestNum']).sum().astype(np.int32)) + for xx,task_to_use in enumerate(train_param["Subj"][subj]['Task']): + self.TestNum = train_param["Subj"][subj]['TestNum'][xx] + # for file in range(len(DATA_DIR)): + HD = True if Density == "HD" else False + datapath_task = os.path.join(datapath,task_to_use) + analysispath_task = os.path.join(analysispath,task_to_use) + # if REPRODFLAG is None: + # self.Prod = True if 'NY' in DATA_DIR[ds][file] and 'share' in DATA_DIR[ds][file] else False + # else: + # self.Prod = REPRODFLAG + print("load data from: ", datapath_task) + ecogdata = h5py.File(os.path.join(datapath_task,'gdat_env.mat'),'r') + ecog = np.asarray(ecogdata['gdat_env']) + # ecog = np.minimum(ecog,data_range_max[ds][file]) + ecog = np.minimum(ecog,30) + event_range = None if "EventRange" not in subj_param.keys() else subj_param["EventRange"] + # bad_samples = [] if "BadSamples" not in subj_param.keys() else subj_param["BadSamples"] + start_ind_wave = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['onset'][0] + start_ind_wave = np.asarray([start_ind_wave[i][0,0] for i in range(start_ind_wave.shape[0])])[:event_range] + end_ind_wave = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['offset'][0] + end_ind_wave = np.asarray([end_ind_wave[i][0,0] for i in range(end_ind_wave.shape[0])])[:event_range] + + if self.Prod: + start_ind_re_wave = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['onset_r'][0] + start_ind_re_wave = np.asarray([start_ind_re_wave[i][0,0] for i in range(start_ind_re_wave.shape[0])])[:event_range] + end_ind_re_wave = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['offset_r'][0] + end_ind_re_wave = np.asarray([end_ind_re_wave[i][0,0] for i in range(end_ind_re_wave.shape[0])])[:event_range] + if HD: + start_ind = (start_ind_wave*1.0/self.ORG_WAVE_FS*self.DOWN_ECOG_FS).astype(np.int64) # in ECoG sample + start_ind_wave_down = (start_ind_wave*1.0/self.ORG_WAVE_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind = (end_ind_wave*1.0/self.ORG_WAVE_FS*self.DOWN_ECOG_FS).astype(np.int64) # in ECoG sample + end_ind_wave_down = (end_ind_wave*1.0/self.ORG_WAVE_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_valid = np.delete(start_ind,bad_samples) + end_ind_valid = np.delete(end_ind,bad_samples) + start_ind_wave_down_valid = np.delete(start_ind_wave_down,bad_samples) + end_ind_wave_down_valid = np.delete(end_ind_wave_down,bad_samples) + try: + bad_samples = allsubj_param['BadSamples'][subj][task_to_use] + except: + bad_samples = [] + bad_samples_ = np.concatenate([bad_samples_,np.array(bad_samples)]) + else: + start_ind = (start_ind_wave*1.0/self.ORG_ECOG_FS_NY*self.DOWN_ECOG_FS).astype(np.int64) + 
start_ind_wave_down = (start_ind_wave*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind = (end_ind_wave*1.0/self.ORG_ECOG_FS_NY*self.DOWN_ECOG_FS).astype(np.int64) + end_ind_wave_down = (end_ind_wave*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + if self.Prod: + bad_samples_HD = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['badrsp'][0] + else: + bad_samples_HD = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['badevent'][0] + bad_samples_HD = np.asarray([bad_samples_HD[i][0,0] for i in range(bad_samples_HD.shape[0])]) + bad_samples_ = np.concatenate((bad_samples_,bad_samples_HD)) + bad_samples_HD = np.where(np.logical_or(np.logical_or(bad_samples_HD==1, bad_samples_HD==2) , bad_samples_HD==4))[0] + start_ind_valid = np.delete(start_ind,bad_samples_HD) + end_ind_valid = np.delete(end_ind,bad_samples_HD) + start_ind_wave_down_valid = np.delete(start_ind_wave_down,bad_samples_HD) + end_ind_wave_down_valid = np.delete(end_ind_wave_down,bad_samples_HD) + if self.Prod: + start_ind_re = (start_ind_re_wave*1.0/self.ORG_ECOG_FS_NY*self.DOWN_ECOG_FS).astype(np.int64) + start_ind_re_wave_down = (start_ind_re_wave*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_re = (end_ind_re_wave*1.0/self.ORG_ECOG_FS_NY*self.DOWN_ECOG_FS).astype(np.int64) + end_ind_re_wave_down = (end_ind_re_wave*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_re_valid = np.delete(start_ind_re,bad_samples_HD) + end_ind_re_valid = np.delete(end_ind_re,bad_samples_HD) + start_ind_re_wave_down_valid = np.delete(start_ind_re_wave_down,bad_samples_HD) + end_ind_re_wave_down_valid = np.delete(end_ind_re_wave_down,bad_samples_HD) + + + ecog = signal.resample_poly(ecog,self.DOWN_ECOG_FS*10000,30517625,axis=0) if HD else signal.resample_poly(ecog,self.DOWN_ECOG_FS,self.ORG_ECOG_FS_NY,axis=0) # resample to 125 hz + baseline_ind = np.concatenate([np.arange(start_ind_valid[i]-self.DOWN_ECOG_FS//4,start_ind_valid[i]-self.DOWN_ECOG_FS//20) \ + for i in range(len(start_ind_valid))]) #baseline: 1/4 s - 1/20 s before stimulis onset + baseline_ind_spec = np.concatenate([np.arange((start_ind_valid[i]*1.0/self.DOWN_ECOG_FS*self.DOWN_TF_FS-self.DOWN_TF_FS//4).astype(np.int64),(start_ind_valid[i]*1.0/self.DOWN_ECOG_FS*self.DOWN_TF_FS-self.DOWN_TF_FS//8).astype(np.int64)) \ + for i in range(len(start_ind_valid))]) #baseline: 1/4 s - 1/8 s before stimulis onset + baseline = ecog[baseline_ind] + statics_ecog = baseline.mean(axis=0,keepdims=True)+1E-10, np.sqrt(baseline.var(axis=0, keepdims=True))+1E-10 + + ecog = (ecog - statics_ecog[0])/statics_ecog[1] + ecog = np.minimum(ecog,10)#5) + ecog_len_+= [ecog.shape[0]] + ecog_+=[ecog] + + start_ind_train_ += [start_ind[:-self.TestNum] + np.cumsum(ecog_len_)[-2]] + end_ind_train_ += [end_ind[:-self.TestNum] + np.cumsum(ecog_len_)[-2]] + end_ind_valid_train_ += [end_ind_valid[:-self.TestNum] + np.cumsum(ecog_len_)[-2]] + start_ind_valid_train = start_ind_valid[:-self.TestNum] + np.cumsum(ecog_len_)[-2] + start_ind_valid_train_ += [start_ind_valid_train] + start_ind_wave_down_train = start_ind_wave_down[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_wave_down_train_ += [start_ind_wave_down_train] + end_ind_wave_down_train = end_ind_wave_down[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_wave_down_train_ += [end_ind_wave_down_train] + 
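A quick worked example of the onset conversion above for the HD branch, where the Events.mat onsets are audio-sample indices at ORG_WAVE_FS = 24414.1 Hz. ORG_WAVE_FS and DOWN_WAVE_FS come from AllSubjectInfo.json; DOWN_ECOG_FS = 125 is an assumption consistent with the "resample to 125 hz" comment.

```python
ORG_WAVE_FS, DOWN_ECOG_FS, DOWN_WAVE_FS = 24414.1, 125, 16000

start_ind_wave = 244141                                                  # onset ~10.0 s into the recording
start_ind = int(start_ind_wave / ORG_WAVE_FS * DOWN_ECOG_FS)             # 1250, index into the 125 Hz ECoG
start_ind_wave_down = int(start_ind_wave / ORG_WAVE_FS * DOWN_WAVE_FS)   # 160000, index into the 16 kHz audio
```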
start_ind_wave_down_valid_train = start_ind_wave_down_valid[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_wave_down_valid_train_ += [start_ind_wave_down_valid_train] + end_ind_wave_down_valid_train = end_ind_wave_down_valid[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_wave_down_valid_train_ += [end_ind_wave_down_valid_train] + + start_ind_test_ += [start_ind[-self.TestNum:] + np.cumsum(ecog_len_)[-2]] + end_ind_test_ += [end_ind[-self.TestNum:] + np.cumsum(ecog_len_)[-2]] + end_ind_valid_test_ += [end_ind_valid[-self.TestNum:] + np.cumsum(ecog_len_)[-2]] + start_ind_valid_test = start_ind_valid[-self.TestNum:] + np.cumsum(ecog_len_)[-2] + start_ind_valid_test_ += [start_ind_valid_test] + start_ind_wave_down_test = start_ind_wave_down[-self.TestNum:] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_wave_down_test_ += [start_ind_wave_down_test] + end_ind_wave_down_test = end_ind_wave_down[-self.TestNum:] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_wave_down_test_ += [end_ind_wave_down_test] + start_ind_wave_down_valid_test = start_ind_wave_down_valid[-self.TestNum:] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_wave_down_valid_test_ += [start_ind_wave_down_valid_test] + end_ind_wave_down_valid_test = end_ind_wave_down_valid[-self.TestNum:] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_wave_down_valid_test_ += [end_ind_wave_down_valid_test] + + if self.Prod: + start_ind_re_train_ += [start_ind_re[:-self.TestNum] + np.cumsum(ecog_len_)[-2]] + end_ind_re_train_ += [end_ind_re[:-self.TestNum] + np.cumsum(ecog_len_)[-2]] + end_ind_re_valid_train_ += [end_ind_re_valid[:-self.TestNum] + np.cumsum(ecog_len_)[-2]] + start_ind_re_validtrain_ = start_ind_re_valid[:-self.TestNum] + np.cumsum(ecog_len_)[-2] + start_ind_re_valid_train_ += [start_ind_re_validtrain_] + start_ind_re_wave_downtrain_ = start_ind_re_wave_down[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_re_wave_down_train_ += [start_ind_re_wave_downtrain_] + end_ind_re_wave_downtrain_ = end_ind_re_wave_down[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_re_wave_down_train_ += [end_ind_re_wave_downtrain_] + start_ind_re_wave_down_validtrain_ = start_ind_re_wave_down_valid[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_re_wave_down_valid_train_ += [start_ind_re_wave_down_validtrain_] + end_ind_re_wave_down_validtrain_ = end_ind_re_wave_down_valid[:-self.TestNum] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_re_wave_down_valid_train_ += [end_ind_re_wave_down_validtrain_] + + start_ind_re_test_ += [start_ind_re[-self.TestNum:] + np.cumsum(ecog_len_)[-2]] + end_ind_re_test_ += [end_ind_re[-self.TestNum:] + np.cumsum(ecog_len_)[-2]] + end_ind_re_valid_test_ += [end_ind_re_valid[-self.TestNum:] + np.cumsum(ecog_len_)[-2]] + start_ind_re_validtest_ = start_ind_re_valid[-self.TestNum:] + np.cumsum(ecog_len_)[-2] + start_ind_re_valid_test_ += [start_ind_re_validtest_] + start_ind_re_wave_downtest_ = start_ind_re_wave_down[-self.TestNum:] + 
(np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_re_wave_down_test_ += [start_ind_re_wave_downtest_] + end_ind_re_wave_downtest_ = end_ind_re_wave_down[-self.TestNum:] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_re_wave_down_test_ += [end_ind_re_wave_downtest_] + start_ind_re_wave_down_validtest_ = start_ind_re_wave_down_valid[-self.TestNum:] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + start_ind_re_wave_down_valid_test_ += [start_ind_re_wave_down_validtest_] + end_ind_re_wave_down_validtest_ = end_ind_re_wave_down_valid[-self.TestNum:] + (np.cumsum(ecog_len_)[-2]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS).astype(np.int64) + end_ind_re_wave_down_valid_test_ += [end_ind_re_wave_down_validtest_] + + if not self.Prod: + spkrdata = h5py.File(os.path.join(datapath_task,'TF32_16k.mat'),'r') + spkr = np.asarray(spkrdata['TFlog']) + spkr = signal.resample(spkr,int(1.0*spkr.shape[0]/self.ORG_TF_FS*self.DOWN_TF_FS),axis=0) + else: + spkr = np.zeros([end_ind[-1],self.SpecBands]) + + samples_for_statics = spkr[start_ind[0]:start_ind[-1]] + # if HD: + # samples_for_statics = samples_for_statics_ + # if not HD: + # samples_for_statics = spkr[start_ind[0]:start_ind[-1]] + if xx==0: + statics_spkr = samples_for_statics.mean(axis=0,keepdims=True)+1E-10, np.sqrt(samples_for_statics.var(axis=0, keepdims=True))+1E-10 + # print(statics_spkr) + if self.Wipenoise: + for samples in range(start_ind.shape[0]): + if not np.isnan(start_ind[samples]): + if samples ==0: + spkr[:start_ind[samples]] = 0 + else: + spkr[end_ind[samples-1]:start_ind[samples]] = 0 + if samples ==start_ind.shape[0]-1: + spkr[end_ind[samples]:] = 0 + spkr = (np.clip(spkr,0.,70.)-35.)/35. 
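The `(np.clip(spkr,0.,70.)-35.)/35.` step above (repeated for `spkr_re` further down) rescales the log-spectrogram values from their clipped [0, 70] range onto [-1, 1]:

```python
import numpy as np

# 0 -> -1, 35 -> 0, 70 -> +1; values outside [0, 70] are clipped first.
for v in (0., 35., 70.):
    print((np.clip(v, 0., 70.) - 35.) / 35.)   # -1.0, 0.0, 1.0
```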
+ # spkr = (spkr - statics_spkr[0])/statics_spkr[1] + spkr_trim = np.zeros([int(ecog.shape[0]*1.0/self.DOWN_ECOG_FS*self.DOWN_TF_FS),spkr.shape[1]]) + if spkr.shape[0]>spkr_trim.shape[0]: + spkr_trim = spkr[:spkr_trim.shape[0]] + spkr = spkr_trim + else: + spkr_trim[:spkr.shape[0]] = spkr + spkr = spkr_trim + spkr_+=[spkr] + + if not self.Prod: + wavedata = h5py.File(os.path.join(datapath_task,'spkr_16k.mat'),'r') + wavearray = np.asarray(wavedata['spkr']) + wave_trim = np.zeros([int(ecog.shape[0]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS),wavearray.shape[1]]) + else: + wavearray = np.zeros([int(ecog.shape[0]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS),1]) + wave_trim = np.zeros([int(ecog.shape[0]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS),1]) + + if wavearray.shape[0]>wave_trim.shape[0]: + wave_trim = wavearray[:wave_trim.shape[0]] + wavearray = wave_trim + else: + wave_trim[:wavearray.shape[0]] = wavearray + wavearray = wave_trim + wave_+=[wavearray] + if cfg.MODEL.WAVE_BASED: + wave_spec = wave2spec(torch.tensor(wavearray.T),n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB,power=2 if cfg.MODEL.POWER_SYNTH else 1)[0].detach().cpu().numpy() + # wave_spec_amp = wave2spec(torch.tensor(wavearray.T),n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB,to_db=False,power=2 if cfg.MODEL.POWER_SYNTH else 1)[0].detach().cpu().numpy() + noisesample = wave_spec[baseline_ind_spec] + for samples in range(start_ind.shape[0]): + if not np.isnan(start_ind[samples]): + if samples ==0: + wave_spec[:start_ind[samples]] = -1 + else: + wave_spec[end_ind[samples-1]:start_ind[samples]] = -1 + if samples ==start_ind.shape[0]-1: + wave_spec[end_ind[samples]:] = -1 + wave_spec_ +=[wave_spec] + else: + noisesample = spkr[...,baseline_ind_spec] + noisesample_ += [noisesample] + if self.Prod: + # spkr_redata = h5py.File(os.path.join(datapath_task,'TFzoom'+str(self.SpecBands)+'_denoise_16k_wide.mat'),'r') + spkr_redata = h5py.File(os.path.join(datapath_task,'TFzoom'+str(self.SpecBands)+'_denoise_16k_lownoisedb.mat'),'r') + # spkr_redata = h5py.File(os.path.join(datapath_task,'TFzoom'+str(self.SpecBands)+'_denoise_16k.mat'),'r') + # spkr_redata = h5py.File(os.path.join(datapath_task,'TFzoom'+str(self.SpecBands)+'_16k_log10.mat'),'r') + spkr_re = np.asarray(spkr_redata['TFlog']) + spkr_re = signal.resample(spkr_re,int(1.0*spkr_re.shape[0]/self.ORG_TF_FS*self.DOWN_TF_FS),axis=0) + if HD: + samples_for_statics_re = samples_for_statics_re_ + if not HD: + samples_for_statics_re = spkr_re[start_ind_re[0]:start_ind_re[-1]] + # samples_for_statics_re = spkr_re[statics_samples_spkr_re[ds][file*2]:statics_samples_spkr_re[ds][file*2+1]] + if xx==0: + statics_spkr_re = samples_for_statics_re.mean(axis=0,keepdims=True)+1E-10, np.sqrt(samples_for_statics_re.var(axis=0, keepdims=True))+1E-10 + # print(statics_spkr_re) + if self.Wipenoise: + if subj is not "NY717" or (task_to_use is not 'VisRead' and task_to_use is not 'PicN'): + for samples in range(start_ind_re.shape[0]): + if not np.isnan(start_ind_re[samples]): + if samples ==0: + spkr_re[:start_ind_re[samples]] = 0 + else: + spkr_re[end_ind_re[samples-1]:start_ind_re[samples]] = 0 + if samples ==start_ind_re.shape[0]-1: + spkr_re[end_ind_re[samples]:] = 0 + spkr_re = (np.clip(spkr_re,0.,70.)-35.)/35. + # spkr_re = (np.clip(spkr_re,0.,50.)-25.)/25. 
+ # spkr_re = (spkr_re - statics_spkr_re[0])/statics_spkr_re[1] + spkr_re_trim = np.zeros([int(ecog.shape[0]*1.0/self.DOWN_ECOG_FS*self.DOWN_TF_FS),spkr_re.shape[1]]) + if spkr_re.shape[0]>spkr_re_trim.shape[0]: + spkr_re_trim = spkr_re[:spkr_re_trim.shape[0]] + spkr_re = spkr_re_trim + else: + spkr_re_trim[:spkr_re.shape[0]] = spkr_re + spkr_re = spkr_re_trim + spkr_re_+=[spkr_re] + + + # wave_redata = h5py.File(os.path.join(datapath_task,'zoom_denoise_16k.mat'),'r') + wave_redata = h5py.File(os.path.join(datapath_task,'zoom_16k.mat'),'r') + wave_rearray = np.asarray(wave_redata['zoom']) + wave_rearray = wave_rearray.T + wave_re_trim = np.zeros([int(ecog.shape[0]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS),wave_rearray.shape[1]]) + wave_redata_denoise = h5py.File(os.path.join(datapath_task,'zoom_denoise_16k.mat'),'r') + wave_rearray_denoise = np.asarray(wave_redata_denoise['zoom']) + wave_rearray_denoise = wave_rearray_denoise.T + wave_re_trim_denoise = np.zeros([int(ecog.shape[0]*1.0/self.DOWN_ECOG_FS*self.DOWN_WAVE_FS),wave_rearray_denoise.shape[1]]) + if wave_rearray.shape[0]>wave_re_trim.shape[0]: + wave_re_trim = wave_rearray[:wave_re_trim.shape[0]] + wave_rearray = wave_re_trim + wave_re_trim_denoise = wave_rearray_denoise[:wave_re_trim_denoise.shape[0]] + wave_rearray_denoise = wave_re_trim_denoise + + else: + wave_re_trim[:wave_rearray.shape[0]] = wave_rearray + wave_rearray = wave_re_trim + wave_re_trim_denoise[:wave_rearray_denoise.shape[0]] = wave_rearray_denoise + wave_rearray_denoise = wave_re_trim_denoise + wave_re_+=[wave_rearray] + wave_re_denoise_+=[wave_rearray_denoise] + + if cfg.MODEL.WAVE_BASED: + wave_re_spec = wave2spec(torch.tensor(wave_rearray.T),n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB,power=2 if cfg.MODEL.POWER_SYNTH else 1)[0].detach().cpu().numpy() + wave_re_spec_amp = wave2spec(torch.tensor(wave_rearray.T),n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB,to_db=False,power=2 if cfg.MODEL.POWER_SYNTH else 1)[0].detach().cpu().numpy() + wave_re_spec_denoise = wave2spec(torch.tensor(wave_rearray_denoise.T),n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB,power=2 if cfg.MODEL.POWER_SYNTH else 1)[0].detach().cpu().numpy() + noisesample_re = wave_re_spec[baseline_ind_spec] + if self.Wipenoise: + if subj is not "NY717" or (task_to_use is not 'VisRead' and task_to_use is not 'PicN'): + for samples in range(start_ind_re.shape[0]): + if not np.isnan(start_ind_re[samples]): + if samples ==0: + wave_re_spec[:start_ind_re[samples]] = -1 + wave_re_spec_amp[:start_ind_re[samples]] = 0 + wave_re_spec_denoise[:start_ind_re[samples]] = -1 + else: + wave_re_spec[end_ind_re[samples-1]:start_ind_re[samples]] = -1 + wave_re_spec_amp[end_ind_re[samples-1]:start_ind_re[samples]] = 0 + wave_re_spec_denoise[end_ind_re[samples-1]:start_ind_re[samples]] = -1 + if samples ==start_ind_re.shape[0]-1: + wave_re_spec[end_ind_re[samples]:] = -1 + wave_re_spec_amp[end_ind_re[samples]:] = 0 + wave_re_spec_denoise[end_ind_re[samples]:] = -1 + wave_re_spec_ +=[wave_re_spec] + wave_re_spec_amp_ +=[wave_re_spec_amp] + wave_re_spec_denoise_ +=[wave_re_spec_denoise] + else: + noisesample_re = spkr_re[...,baseline_ind_spec] + noisesample_re_ += [noisesample_re] + + + + if HD: + label_mat = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['word'][0][:event_range] + else: + label_mat = scipy.io.loadmat(os.path.join(analysispath_task,'Events.mat'))['Events']['correctrsp'][0][:event_range] + label_subset 
= [] + label_mat = np.delete(label_mat,bad_samples_HD) + for i in range(label_mat.shape[0]): + if HD: + label_mati = label_mat[i][0] + else: + label_mati = label_mat[i][0][0][0].lower() + # labels.append(str(label_mati).replace('.wav','')) + label_subset.append(label_mati) + if label_mati not in unique_labels: + unique_labels.append(label_mati) + label_ind = np.zeros([label_mat.shape[0]]) + for i in range(label_mat.shape[0]): + label_ind[i] = unique_labels.index(label_subset[i]) + label_ind = np.asarray(label_ind,dtype=np.int16) + word_train+=[label_ind[:-self.TestNum]] + labels_train+=[label_subset[:-self.TestNum]] + word_test+=[label_ind[-self.TestNum:]] + labels_test+=[label_subset[-self.TestNum:]] + + ################ clean ################## + if not HD: + # bad_samples_ = np.where(bad_samples_==1)[0] + bad_samples_ = np.where(np.logical_or(np.logical_or(bad_samples_==1, bad_samples_==2) , bad_samples_==4))[0] + if HD: + bad_channels = np.array([]) if "BadElec" not in subj_param.keys() else subj_param["BadElec"] + else: + bad_channels = scipy.io.loadmat(os.path.join(analysispath_task,'subj_globals.mat'))['bad_elecs'][0]-1 + # dataset_name = [name for name in DATA_DIR[ds][0].split('/') if 'NY' in name or 'HD' in name] + if HD: + mni_coord = np.array([]) + T1_coord = np.array([]) + else: + csvfile = os.path.join(analysispath,'coordinates.csv') + coord = pandas.read_csv(csvfile) + mni_coord = np.stack([np.array(coord['MNI_x'][:128]),np.array(coord['MNI_y'][:128]),np.array(coord['MNI_z'][:128])],axis=1) + # mni_coord = rearrange(mni_coord,Crop,mode = 'coord') + mni_coord = mni_coord.astype(np.float32) + mni_coord = (mni_coord-np.array([-74.,-23.,-20.]))*2/np.array([74.,46.,54.])-1 + T1_coord = np.stack([np.array(coord['T1_x'][:128]),np.array(coord['T1_y'][:128]),np.array(coord['T1_z'][:128])],axis=1) + # T1_coord = rearrange(T1_coord,NY_crop[ds],mode = 'coord') + T1_coord = T1_coord.astype(np.float32) + T1_coord = (T1_coord-np.array([-74.,-23.,-20.]))*2/np.array([74.,46.,54.])-1 + # for i in range(mni_coord.shape[0]): + # print(i,' ',mni_coord[i]) + percent1 = np.array([float(coord['AR_Percentage'][i].strip("%").strip())/100.0 for i in range(128)]) + percent2 = np.array([0.0 if isinstance(coord['AR_7'][i],float) else float(coord['AR_7'][i].strip("%").strip())/100.0 for i in range(128)]) + percent = np.stack([percent1,percent2],1) + AR1 = np.array([coord['T1_AnatomicalRegion'][i] for i in range(128)]) + AR2 = np.array([coord['AR_8'][i] for i in range(128)]) + AR = np.stack([AR1,AR2],1) + regions = np.array([AR[i,np.argmax(percent,1)[i]] for i in range(AR.shape[0])]) + mask = np.ones(ecog_[0].shape[1]) + mask[bad_channels] = 0. 
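The MNI normalization above maps each coordinate axis onto roughly [-1, 1] for the assumed per-axis ranges; the endpoints check out:

```python
import numpy as np

# (x - lo) * 2 / span - 1 sends lo -> -1 and lo + span -> +1 on every axis.
lo, span = np.array([-74., -23., -20.]), np.array([74., 46., 54.])
print((lo - lo) * 2 / span - 1)           # [-1. -1. -1.]
print((lo + span - lo) * 2 / span - 1)    # [1. 1. 1.]
```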
+ lastchannel = ecog_[0].shape[1] if not self.UseGridOnly else (128 if Density=="HB" else 64) + if self.ReshapeAsGrid: + regions = self.rearrange(regions,Crop,mode = 'region') + mask = self.rearrange(mask,Crop,mode = 'mask') + mni_coord = self.rearrange(mni_coord,Crop,mode = 'coord') + else: + mask = mask if HD else mask[:lastchannel] + regions = regions if HD else regions[:lastchannel] + mni_coord = mni_coord if HD else mni_coord[:lastchannel] + + + + ecog_ = np.concatenate(ecog_,axis=0) + ecog_ = ecog_ if HD else ecog_[:,:lastchannel] + # start_ind_valid_ = np.concatenate(start_ind_valid_,axis=0) + if HD: + ecog_,statics_ecog_zscore = self.zscore(ecog_,badelec = bad_channels) + elif not flag_zscore: + ecog_,statics_ecog_zscore = self.zscore(ecog_,badelec = bad_channels) + flag_zscore = True + else: + ecog_ = (ecog_-statics_ecog_zscore[0])/statics_ecog_zscore[1] + if bad_channels.size !=0: # if bad_channels is not empty + ecog_[:,bad_channels[bad_channels **Adversarial Latent Autoencoders**
> Stanislav Pidhorskyi, Donald Adjeroh, Gianfranco Doretto
> -> **Abstract:** *Autoencoder networks are unsupervised approaches aiming at combining generative and representational properties by learning simultaneously an encoder-generator map. Although studied extensively, the issues of whether they have the same generative power of GANs, or learn disentangled representations, have not been fully addressed. We introduce an autoencoder that tackles these issues jointly, which we call Adversarial Latent Autoencoder (ALAE). It is a general architecture that can leverage recent improvements on GAN training procedures. We designed two autoencoders: one based on a MLP encoder, and another based on a StyleGAN generator, which we call StyleALAE. We verify the disentanglement properties of both architectures. We show that StyleALAE can not only generate 1024x1024 face images with comparable quality of StyleGAN, but at the same resolution can also produce face reconstructions and manipulations based on real images. This makes ALAE the first autoencoder able to compare with, and go beyond the capabilities of a generator-only type of architecture.* +> **Abstract:** *Autoencoder networks are unsupervised approaches aiming at combining generative and representational properties by learning simultaneously an encoder-generator map. Although studied extensively, the issues of whether they have the same generative power of GANs, or learn disentangled representations, have not been fully addressed. We introduce an autoencoder that tackles these issues jointly, which we call Adversarial Latent Autoencoder (ALAE). It is a general architecture that can leverage recent improvements on GAN training procedures. We designed two autoencoders: one based on a MLP encoder, and another based on a StyleGAN generator, which we call StyleALAE. We verify the disentanglement properties of both architectures. We show that StyleALAE can not only generate 1024x1024 face images with comparable quality of StyleGAN, but at the same resolution can also produce face reconstructions and manipulations based on real images. This makes ALAE the first autoencoder able to compare with, and go beyond, the capabilities of a generator-only type of architecture.* ## Citation * Stanislav Pidhorskyi, Donald A. Adjeroh, and Gianfranco Doretto. Adversarial Latent Autoencoders. In *Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR)*, 2020. [to appear] diff --git a/checkpointer.py b/checkpointer.py index 61e79b83..c8fc8d4f 100644 --- a/checkpointer.py +++ b/checkpointer.py @@ -66,11 +66,12 @@ def save_data(): return save_data() - def load(self, ignore_last_checkpoint=False, file_name=None): + def load(self, ignore_last_checkpoint=False, ignore_auxiliary=False,file_name=None): save_file = os.path.join(self.cfg.OUTPUT_DIR, "last_checkpoint") try: with open(save_file, "r") as last_checkpoint: f = last_checkpoint.read().strip() + f = os.path.join(self.cfg.OUTPUT_DIR, f) except IOError: self.logger.info("No checkpoint found. 
Initializing model from scratch") if file_name is None: @@ -81,7 +82,6 @@ def load(self, ignore_last_checkpoint=False, file_name=None): return {} if file_name is not None: f = file_name - self.logger.info("Loading checkpoint from {}".format(f)) checkpoint = torch.load(f, map_location=torch.device("cpu")) for name, model in self.models.items(): @@ -98,7 +98,8 @@ def load(self, ignore_last_checkpoint=False, file_name=None): else: self.logger.warning("No state dict for model: %s" % name) checkpoint.pop('models') - if "auxiliary" in checkpoint and self.auxiliary: + + if "auxiliary" in checkpoint and self.auxiliary and not ignore_auxiliary: self.logger.info("Loading auxiliary from {}".format(f)) for name, item in self.auxiliary.items(): try: diff --git a/configs/ecog.yaml b/configs/ecog.yaml new file mode 100644 index 00000000..07bcb587 --- /dev/null +++ b/configs/ecog.yaml @@ -0,0 +1,66 @@ + # Config for training ALAE on FFHQ at resolution 1024x1024 + +NAME: ecog +DATASET: + PART_COUNT: 16 + SIZE: 60000 + FFHQ_SOURCE: /data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords + PATH: /data/datasets/ffhq-dataset_new/tfrecords/ffhq/splitted/ffhq-r%02d.tfrecords.%03d + + FLIP_IMAGES: False + + PART_COUNT_TEST: 2 + PATH_TEST: /data/datasets/ffhq-dataset_new/tfrecords/ffhq-test/splitted/ffhq-r%02d.tfrecords.%03d + + SAMPLES_PATH: '' + STYLE_MIX_PATH: style_mixing/test_images/set_ecog + SPEC_CHANS: 64 + TEMPORAL_SAMPLES: 128 + BCTS: True + MAX_RESOLUTION_LEVEL: 7 +MODEL: + LATENT_SPACE_SIZE: 128 + LAYER_COUNT: 6 + MAX_CHANNEL_COUNT: 512 + START_CHANNEL_COUNT: 16 + DLATENT_AVG_BETA: 0.995 + MAPPING_LAYERS: 8 + TRUNCATIOM_CUTOFF: 5 + CHANNELS: 1 + UNIQ_WORDS: 50 + #####TAKE OFF CHECKLIST!!!######## + AVERAGE_W: False + TEMPORAL_W: False + RESIDUAL: True + W_CLASSIFIER: False + CYCLE: True + ATTENTIONAL_STYLE: False + #T 4 8 16 32 64 128 + ATTENTION: [False, False, False, False, True, True] + HEADS: 1 + # ATTENTION: [] +OUTPUT_DIR: training_artifacts/vis +# OUTPUT_DIR: training_artifacts/ecog_residual_cycle_attention3264wStyleIN_specchan64_more_attentfeatures_heads4 +##################################### + +TRAIN: + W_WEIGHT: 1 + CYCLE_WEIGHT: 1 + BASE_LEARNING_RATE: 0.002 + EPOCHS_PER_LOD: 16 + LEARNING_DECAY_RATE: 0.1 + LEARNING_DECAY_STEPS: [96] + TRAIN_EPOCHS: 112 + # 4 8 16 32 64 128 256 + LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32] # If GPU memory ~16GB reduce last number from 32 to 24 + LOD_2_BATCH_4GPU: [64, 64, 64, 64, 32, 16] + LOD_2_BATCH_2GPU: [64, 64, 64, 64, 32, 8] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 32] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16] + # LOD_2_BATCH_1GPU: [128, 128, 128, 128, 64, 32] + # LOD_2_BATCH_1GPU: [512, 256, 256, 128, 64, 16] + LOD_2_BATCH_1GPU: [64, 64, 64, 64, 32, 16] + + LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003] + # LEARNING_RATES: [0.0015, 0.0015, 0.0005, 0.0003, 0.0003, 0.0002] diff --git a/configs/ecog_style2.yaml b/configs/ecog_style2.yaml new file mode 100644 index 00000000..8c085b96 --- /dev/null +++ b/configs/ecog_style2.yaml @@ -0,0 +1,132 @@ + # Config for training ALAE on FFHQ at resolution 1024x1024 + +NAME: ecog +DATASET: + PART_COUNT: 16 + SIZE: 60000 + FFHQ_SOURCE: /data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords + PATH: /data/datasets/ffhq-dataset_new/tfrecords/ffhq/splitted/ffhq-r%02d.tfrecords.%03d + + FLIP_IMAGES: False + + PART_COUNT_TEST: 4 + PATH_TEST: /data/datasets/ffhq-dataset_new/tfrecords/ffhq-test/splitted/ffhq-r%02d.tfrecords.%03d 
+ + SAMPLES_PATH: '' + STYLE_MIX_PATH: style_mixing/test_images/set_ecog + SPEC_CHANS: 64 + TEMPORAL_SAMPLES: 128 + BCTS: True + MAX_RESOLUTION_LEVEL: 7 + SUBJECT: ['NY742'] +MODEL: + #####TAKE OFF CHECKLIST!!!######## + N_FORMANTS: 6 + N_FORMANTS_NOISE: 1 + N_FORMANTS_ECOG: 6 + WAVE_BASED : True + DO_MEL_GUIDE : False + BGNOISE_FROMDATA: True + N_FFT : 256 + NOISE_DB : -50 #-50 + MAX_DB : 22.5 #probablity 28 is better + NOISE_DB_AMP : -25 + MAX_DB_AMP : 14 + POWER_SYNTH: True + + LESS_TEMPORAL_FEATURE: True + LATENT_SPACE_SIZE: 128 + LAYER_COUNT: 6 + MAX_CHANNEL_COUNT: 512 + START_CHANNEL_COUNT: 16 + DLATENT_AVG_BETA: 0.995 + MAPPING_LAYERS: 8 + TRUNCATIOM_CUTOFF: 5 + CHANNELS: 1 + UNIQ_WORDS: 50 + #MAPPING_FROM_ECOG: "ECoGMappingBottleneck" #ECoGMappingBottlenecklstm1, ECoGMappingBottlenecklstm2 + #MAPPING_FROM_ECOG: "ECoGMappingBottlenecklstm1" + #MAPPING_FROM_ECOG: "ECoGMappingBottlenecklstm" + MAPPING_FROM_ECOG: "ECoGMappingBottlenecklstm_pure" + ONEDCONFIRST: True + RNN_TYPE: 'LSTM' + RNN_LAYERS: 4 + RNN_COMPUTE_DB_LOUDNESS: True + BIDIRECTION: True + # MAPPING_FROM_ECOG: "ECoGMappingTransformer" + ECOG: False #will be overloaded if FINETUNE + SUPLOSS_ON_ECOGF: False # will be overloaded to FIX_GEN if FINETUNE,spec supervise loss only apply to ecog encoder + W_SUP: False + GAN: True + GENERATOR: "GeneratorFormant" + ENCODER: "EncoderFormant" + AVERAGE_W: True + TEMPORAL_W: True + GLOBAL_W: True + TEMPORAL_GLOBAL_CAT: True + RESIDUAL: True + W_CLASSIFIER: False + CYCLE: False + ATTENTIONAL_STYLE: True + #T 4 8 16 32 64 128 + ATTENTION: [False, False, False, False, False, False] + HEADS: 1 + APPLY_PPL: False + APPLY_PPL_D: False + PPL_WEIGHT: 100 + PPL_GLOBAL_WEIGHT: 0 + PPLD_WEIGHT: 1 + PPLD_GLOBAL_WEIGHT: 0 + COMMON_Z: True + TRANSFORMER: + HIDDEN_DIM : 256 + DIM_FEEDFORWARD : 256 + ENCODER_ONLY : False + ATTENTIONAL_MASK : False + N_HEADS : 4 + NON_LOCAL: True + # ATTENTION: [] +#OUTPUT_DIR: output/ecog_10241800_lstm1 #training_artifacts/debug +#OUTPUT_DIR: output/ecog_10241800_lstm2 +#OUTPUT_DIR: output/ecog_11011800_conv #after change loudness encoder +#OUTPUT_DIR: output/ecog_11011800_lstm1 #after change loudness encoder +OUTPUT_DIR: output/ecog_11021800_lstm1 #after change loudness encoder +# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampamploss +# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampsynth_masknormed +# OUTPUT_DIR: training_artifacts/debug_f1f2linearmel +# OUTPUT_DIR: training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicformantsemph +# OUTPUT_DIR: training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicnoiseformantsemphmore +# OUTPUT_DIR: training_artifacts/formantsythv2_wavebased_NY742_constraintonFB_Bconstrainrefined_absfreq_4formants_1noiseformants_bgnoise_noisemapping_freqconv_duomask +# OUTPUT_DIR: training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_noprogressive_HBw_ppl_ppld_localreg_ecogf_w_spec_sup +# OUTPUT_DIR: training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_ppl_ppld +# OUTPUT_DIR: training_artifacts/ecog_residual_cycle_attention3264wStyleIN_specchan64_more_attentfeatures_heads4 +FINETUNE: + FINETUNE: True + FIX_GEN: True + ENCODER_GUIDE: True + SPECSUP: True +##################################### + +TRAIN: + PROGRESSIVE: False + W_WEIGHT: 1 + CYCLE_WEIGHT: 1 + BASE_LEARNING_RATE: 0.002 + EPOCHS_PER_LOD: 16 + LEARNING_DECAY_RATE: 0.1 + LEARNING_DECAY_STEPS: [96] + TRAIN_EPOCHS: 60 + # 4 8 16 32 64 128 256 + LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32] # If GPU 
memory ~16GB reduce last number from 32 to 24 + LOD_2_BATCH_4GPU: [64, 64, 64, 64, 32, 16] + LOD_2_BATCH_2GPU: [64, 64, 64, 64, 32, 16] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 32] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16] + # LOD_2_BATCH_1GPU: [128, 128, 128, 128, 64, 32] + # LOD_2_BATCH_1GPU: [512, 256, 256, 128, 64, 16] + LOD_2_BATCH_1GPU: [64, 64, 64, 64, 32, 16] + BATCH_SIZE : 32 + # BATCH_SIZE : 2 + LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003] + # LEARNING_RATES: [0.0015, 0.0015, 0.0005, 0.0003, 0.0003, 0.0002] diff --git a/configs/ecog_style2_a.yaml b/configs/ecog_style2_a.yaml new file mode 100644 index 00000000..d16a3a5f --- /dev/null +++ b/configs/ecog_style2_a.yaml @@ -0,0 +1,121 @@ + # Config for training ALAE on FFHQ at resolution 1024x1024 + +NAME: ecog +DATASET: + PART_COUNT: 16 + SIZE: 60000 + FFHQ_SOURCE: /data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords + PATH: /data/datasets/ffhq-dataset_new/tfrecords/ffhq/splitted/ffhq-r%02d.tfrecords.%03d + + FLIP_IMAGES: False + + PART_COUNT_TEST: 4 + PATH_TEST: /data/datasets/ffhq-dataset_new/tfrecords/ffhq-test/splitted/ffhq-r%02d.tfrecords.%03d + + SAMPLES_PATH: '' + STYLE_MIX_PATH: style_mixing/test_images/set_ecog + SPEC_CHANS: 64 + TEMPORAL_SAMPLES: 128 + BCTS: True + MAX_RESOLUTION_LEVEL: 7 + SUBJECT: ['NY742'] +MODEL: + #####TAKE OFF CHECKLIST!!!######## + N_FORMANTS: 6 + N_FORMANTS_NOISE: 1 + N_FORMANTS_ECOG: 2 + WAVE_BASED : True + DO_MEL_GUIDE : False + BGNOISE_FROMDATA: True + N_FFT : 256 + NOISE_DB : -50 #-50 + MAX_DB : 22.5 #probablity 28 is better + NOISE_DB_AMP : -25 + MAX_DB_AMP : 14 + POWER_SYNTH: True + + LESS_TEMPORAL_FEATURE: True + LATENT_SPACE_SIZE: 128 + LAYER_COUNT: 6 + MAX_CHANNEL_COUNT: 512 + START_CHANNEL_COUNT: 16 + DLATENT_AVG_BETA: 0.995 + MAPPING_LAYERS: 8 + TRUNCATIOM_CUTOFF: 5 + CHANNELS: 1 + UNIQ_WORDS: 50 + MAPPING_FROM_ECOG: "ECoGMappingBottleneck" + # MAPPING_FROM_ECOG: "ECoGMappingTransformer" + ECOG: False #will be overloaded if FINETUNE + SUPLOSS_ON_ECOGF: False # will be overloaded to FIX_GEN if FINETUNE,spec supervise loss only apply to ecog encoder + W_SUP: False + GAN: True + GENERATOR: "GeneratorFormant" + ENCODER: "EncoderFormant" + AVERAGE_W: True + TEMPORAL_W: True + GLOBAL_W: True + TEMPORAL_GLOBAL_CAT: True + RESIDUAL: True + W_CLASSIFIER: False + CYCLE: False + ATTENTIONAL_STYLE: True + #T 4 8 16 32 64 128 + ATTENTION: [False, False, False, False, False, False] + HEADS: 1 + APPLY_PPL: False + APPLY_PPL_D: False + PPL_WEIGHT: 100 + PPL_GLOBAL_WEIGHT: 0 + PPLD_WEIGHT: 1 + PPLD_GLOBAL_WEIGHT: 0 + COMMON_Z: True + TRANSFORMER: + HIDDEN_DIM : 256 + DIM_FEEDFORWARD : 256 + ENCODER_ONLY : False + ATTENTIONAL_MASK : False + N_HEADS : 4 + NON_LOCAL: True + # ATTENTION: [] +OUTPUT_DIR: output/audio_11021700 #training_artifacts/debug_ +# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampamploss +# OUTPUT_DIR: training_artifacts/loudnesscomp_han5_ampsynth_masknormed +# OUTPUT_DIR: training_artifacts/debug_f1f2linearmel +# OUTPUT_DIR: training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicformantsemph +# OUTPUT_DIR: training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicnoiseformantsemphmore +# OUTPUT_DIR: training_artifacts/formantsythv2_wavebased_NY742_constraintonFB_Bconstrainrefined_absfreq_4formants_1noiseformants_bgnoise_noisemapping_freqconv_duomask +# OUTPUT_DIR: 
training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_noprogressive_HBw_ppl_ppld_localreg_ecogf_w_spec_sup +# OUTPUT_DIR: training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_ppl_ppld +# OUTPUT_DIR: training_artifacts/ecog_residual_cycle_attention3264wStyleIN_specchan64_more_attentfeatures_heads4 + +FINETUNE: + FINETUNE: False + FIX_GEN: True + ENCODER_GUIDE: True + SPECSUP: True +##################################### + +TRAIN: + PROGRESSIVE: False + W_WEIGHT: 1 + CYCLE_WEIGHT: 1 + BASE_LEARNING_RATE: 0.002 + EPOCHS_PER_LOD: 16 + LEARNING_DECAY_RATE: 0.1 + LEARNING_DECAY_STEPS: [96] + TRAIN_EPOCHS: 60 + # 4 8 16 32 64 128 256 + LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32] # If GPU memory ~16GB reduce last number from 32 to 24 + LOD_2_BATCH_4GPU: [64, 64, 64, 64, 32, 16] + LOD_2_BATCH_2GPU: [64, 64, 64, 64, 32, 16] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 32] + # LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16] + # LOD_2_BATCH_1GPU: [128, 128, 128, 128, 64, 32] + # LOD_2_BATCH_1GPU: [512, 256, 256, 128, 64, 16] + LOD_2_BATCH_1GPU: [64, 64, 64, 64, 32, 16] + BATCH_SIZE : 32 + # BATCH_SIZE : 2 + LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003] + # LEARNING_RATES: [0.0015, 0.0015, 0.0005, 0.0003, 0.0003, 0.0002] diff --git a/dataloader_ecog.py b/dataloader_ecog.py new file mode 100644 index 00000000..2ffac8fc --- /dev/null +++ b/dataloader_ecog.py @@ -0,0 +1,320 @@ +# Copyright 2019-2020 Stanislav Pidhorskyi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import dareblopy as db +import random + +import numpy as np +import torch +import torch.tensor +import torch.utils +import torch.utils.data +import time +import math +from ECoGDataSet import ECoGDataset +cpu = torch.device('cpu') + +# class myDataLoader(torch.utils.data.DataLoader): +# def __init__(self,dataset,batch_size, shuffle,drop_last): +# super(myDataLoader).__init__() + + +class TFRecordsDataset: + def __init__(self, cfg, logger, rank=0, world_size=1, buffer_size_mb=200, channels=3, seed=None, train=True, needs_labels=False,param=None): + self.param = param + self.dataset = ECoGDataset(cfg.DATASET.SUBJECT,mode='train' if train else 'test',world_size=world_size) + self.noise_dist = self.dataset.meta_data['noisesample_re_alldataset'][0] + self.cfg = cfg + self.logger = logger + self.rank = rank + self.last_data = "" + if train: + self.part_count = cfg.DATASET.PART_COUNT + self.part_size = cfg.DATASET.SIZE // self.part_count + else: + self.part_count = cfg.DATASET.PART_COUNT_TEST + self.part_size = cfg.DATASET.SIZE_TEST // self.part_count + self.workers = [] + self.workers_active = 0 + self.iterator = None + self.filenames = {} + self.batch_size = cfg.TRAIN.BATCH_SIZE//world_size if train else len(self.dataset) + self.features = {} + self.channels = channels + self.seed = seed + self.train = train + self.needs_labels = needs_labels + + assert self.part_count % world_size == 0 + + self.part_count_local = self.part_count // world_size + + if train: + path = cfg.DATASET.PATH + else: + path = cfg.DATASET.PATH_TEST + + for r in range(2, cfg.DATASET.MAX_RESOLUTION_LEVEL + 1): + files = [] + for i in range(self.part_count_local * rank, self.part_count_local * (rank + 1)): + file = path % (r, i) + files.append(file) + self.filenames[r] = files + + self.buffer_size_b = 1024 ** 2 * buffer_size_mb + + self.current_filenames = [] + self.iterator = torch.utils.data.DataLoader(self.dataset, + batch_size=self.batch_size, + shuffle=True if self.train else False, + drop_last=True if self.train else False) + def reset(self, lod, batch_size): + assert lod in self.filenames.keys() + self.current_filenames = self.filenames[lod] + if batch_size!=self.batch_size: + self.iterator = torch.utils.data.DataLoader(self.dataset, + batch_size=batch_size, + shuffle=True if self.train else False, + drop_last=True if self.train else False) + self.batch_size = batch_size + self.dataset.current_lod=lod + + # self.dataset.img_size = 2 ** lod + + # if self.needs_labels: + # self.features = { + # # 'shape': db.FixedLenFeature([3], db.int64), + # 'data': db.FixedLenFeature([self.channels, img_size, img_size], db.uint8), + # 'label': db.FixedLenFeature([], db.int64) + # } + # else: + # self.features = { + # # 'shape': db.FixedLenFeature([3], db.int64), + # 'data': db.FixedLenFeature([self.channels, img_size, img_size], db.uint8) + # } + + # buffer_size = self.buffer_size_b // (self.channels * img_size * img_size) + + # if self.seed is None: + # seed = np.uint64(time.time() * 1000) + # else: + # seed = self.seed + # self.logger.info('!' * 80) + # self.logger.info('! Seed is used for to shuffle data in TFRecordsDataset!') + # self.logger.info('!' 
* 80) + + # self.iterator = db.ParsedTFRecordsDatasetIterator(self.current_filenames, self.features, self.batch_size, buffer_size, seed=seed) + def __iter__(self): + return iter(self.iterator) + + def __len__(self): + return len(self.dataset)#self.part_count_local * self.part_size + + +def make_dataloader(cfg, logger, dataset, GPU_batch_size, local_rank, numpy=False): + class BatchCollator(object): + def __init__(self, device=torch.device("cpu")): + self.device = device + self.flip = cfg.DATASET.FLIP_IMAGES + self.numpy = numpy + + def __call__(self, batch): + with torch.no_grad(): + x, = batch + if self.flip: + flips = [(slice(None, None, None), slice(None, None, None), slice(None, None, random.choice([-1, None]))) for _ in range(x.shape[0])] + x = np.array([img[flip] for img, flip in zip(x, flips)]) + if self.numpy: + return x + x = torch.tensor(x, requires_grad=True, device=torch.device(self.device), dtype=torch.float32) + return x + + batches = db.data_loader(iter(dataset), BatchCollator(local_rank), len(dataset) // GPU_batch_size) + + return batches + + +# def make_dataloader_y(cfg, logger, dataset, GPU_batch_size, local_rank): +# class BatchCollator(object): +# def __init__(self, device=torch.device("cpu")): +# self.device = device +# self.flip = cfg.DATASET.FLIP_IMAGES + +# def __call__(self, batch): +# with torch.no_grad(): +# x, y = batch +# if self.flip: +# flips = [(slice(None, None, None), slice(None, None, None), slice(None, None, random.choice([-1, None]))) for _ in range(x.shape[0])] +# x = np.array([img[flip] for img, flip in zip(x, flips)]) +# x = torch.tensor(x, requires_grad=True, device=torch.device(self.device), dtype=torch.float32) +# return x, y + +# batches = db.data_loader(iter(dataset), BatchCollator(local_rank), len(dataset) // GPU_batch_size) + +# return batches + + +# class TFRecordsDatasetImageNet: +# def __init__(self, cfg, logger, rank=0, world_size=1, buffer_size_mb=200, channels=3, seed=None, train=True, needs_labels=False): +# self.cfg = cfg +# self.logger = logger +# self.rank = rank +# self.last_data = "" +# self.part_count = cfg.DATASET.PART_COUNT +# if train: +# self.part_size = cfg.DATASET.SIZE // cfg.DATASET.PART_COUNT +# else: +# self.part_size = cfg.DATASET.SIZE_TEST // cfg.DATASET.PART_COUNT +# self.workers = [] +# self.workers_active = 0 +# self.iterator = None +# self.filenames = {} +# self.batch_size = 512 +# self.features = {} +# self.channels = channels +# self.seed = seed +# self.train = train +# self.needs_labels = needs_labels + +# assert self.part_count % world_size == 0 + +# self.part_count_local = cfg.DATASET.PART_COUNT // world_size + +# if train: +# path = cfg.DATASET.PATH +# else: +# path = cfg.DATASET.PATH_TEST + +# for r in range(2, cfg.DATASET.MAX_RESOLUTION_LEVEL + 1): +# files = [] +# for i in range(self.part_count_local * rank, self.part_count_local * (rank + 1)): +# file = path % (r, i) +# files.append(file) +# self.filenames[r] = files + +# self.buffer_size_b = 1024 ** 2 * buffer_size_mb + +# self.current_filenames = [] + +# def reset(self, lod, batch_size): +# assert lod in self.filenames.keys() +# self.current_filenames = self.filenames[lod] +# self.batch_size = batch_size + +# if self.train: +# img_size = 2 ** lod + 2 ** (lod - 3) +# else: +# img_size = 2 ** lod + +# if self.needs_labels: +# self.features = { +# 'data': db.FixedLenFeature([3, img_size, img_size], db.uint8), +# 'label': db.FixedLenFeature([], db.int64) +# } +# else: +# self.features = { +# 'data': db.FixedLenFeature([3, img_size, img_size], db.uint8) +# } 
+ +# buffer_size = self.buffer_size_b // (self.channels * img_size * img_size) + +# if self.seed is None: +# seed = np.uint64(time.time() * 1000) +# else: +# seed = self.seed +# self.logger.info('!' * 80) +# self.logger.info('! Seed is used for to shuffle data in TFRecordsDataset!') +# self.logger.info('!' * 80) + +# self.iterator = db.ParsedTFRecordsDatasetIterator(self.current_filenames, self.features, self.batch_size, buffer_size, seed=seed) + +# def __iter__(self): +# return self.iterator + +# def __len__(self): +# return self.part_count_local * self.part_size + + +# def make_imagenet_dataloader(cfg, logger, dataset, GPU_batch_size, target_size, local_rank, do_random_crops=True): +# class BatchCollator(object): +# def __init__(self, device=torch.device("cpu")): +# self.device = device +# self.flip = cfg.DATASET.FLIP_IMAGES +# self.size = target_size +# p = math.log2(target_size) +# self.source_size = 2 ** p + 2 ** (p - 3) +# self.do_random_crops = do_random_crops + +# def __call__(self, batch): +# with torch.no_grad(): +# x, = batch + +# if self.do_random_crops: +# images = [] +# for im in x: +# deltax = self.source_size - target_size +# deltay = self.source_size - target_size +# offx = np.random.randint(deltax + 1) +# offy = np.random.randint(deltay + 1) +# im = im[:, offy:offy + self.size, offx:offx + self.size] +# images.append(im) +# x = np.stack(images) + +# if self.flip: +# flips = [(slice(None, None, None), slice(None, None, None), slice(None, None, random.choice([-1, None]))) for _ in range(x.shape[0])] +# x = np.array([img[flip] for img, flip in zip(x, flips)]) +# x = torch.tensor(x, requires_grad=True, device=torch.device(self.device), dtype=torch.float32) + +# return x + +# batches = db.data_loader(iter(dataset), BatchCollator(local_rank), len(dataset) // GPU_batch_size) + +# return batches + + +# def make_imagenet_dataloader_y(cfg, logger, dataset, GPU_batch_size, target_size, local_rank, do_random_crops=True): +# class BatchCollator(object): +# def __init__(self, device=torch.device("cpu")): +# self.device = device +# self.flip = cfg.DATASET.FLIP_IMAGES +# self.size = target_size +# p = math.log2(target_size) +# self.source_size = 2 ** p + 2 ** (p - 3) +# self.do_random_crops = do_random_crops + +# def __call__(self, batch): +# with torch.no_grad(): +# x, y = batch + +# if self.do_random_crops: +# images = [] +# for im in x: +# deltax = self.source_size - target_size +# deltay = self.source_size - target_size +# offx = np.random.randint(deltax + 1) +# offy = np.random.randint(deltay + 1) +# im = im[:, offy:offy+self.size, offx:offx+self.size] +# images.append(im) +# x = np.stack(images) + +# if self.flip: +# flips = [(slice(None, None, None), slice(None, None, None), slice(None, None, random.choice([-1, None]))) for _ in range(x.shape[0])] +# x = np.array([img[flip] for img, flip in zip(x, flips)]) +# x = torch.tensor(x, requires_grad=True, device=torch.device(self.device), dtype=torch.float32) +# return x, y + +# batches = db.data_loader(iter(dataset), BatchCollator(local_rank), len(dataset) // GPU_batch_size) + +# return batches diff --git a/defaults.py b/defaults.py index 41cd317e..612f850f 100644 --- a/defaults.py +++ b/defaults.py @@ -37,8 +37,26 @@ _C.DATASET.MAX_RESOLUTION_LEVEL = 10 +_C.DATASET.SPEC_CHANS=128 +_C.DATASET.TEMPORAL_SAMPLES=128 +_C.DATASET.BCTS = True +_C.DATASET.SUBJECT = [] + _C.MODEL = CN() +_C.MODEL.N_FORMANTS = 4 +_C.MODEL.N_FORMANTS_NOISE = 2 +_C.MODEL.N_FORMANTS_ECOG = 3 +_C.MODEL.WAVE_BASED = False +_C.MODEL.DO_MEL_GUIDE = True 
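These new MODEL defaults (continued below) are the values that configs/ecog_style2.yaml overrides through the yacs pattern already used at the top of ECoGDataSet.py; a minimal sketch:

```python
from defaults import get_cfg_defaults

cfg = get_cfg_defaults()
print(cfg.MODEL.DO_MEL_GUIDE)                    # True, the default defined here
cfg.merge_from_file('configs/ecog_style2.yaml')
print(cfg.MODEL.DO_MEL_GUIDE)                    # False, overridden by the YAML
```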
+_C.MODEL.BGNOISE_FROMDATA = False +_C.MODEL.N_FFT = 256 +_C.MODEL.NOISE_DB = -50 +_C.MODEL.MAX_DB = 22.5 +_C.MODEL.NOISE_DB_AMP = -25 +_C.MODEL.MAX_DB_AMP = 14 +_C.MODEL.POWER_SYNTH = True + _C.MODEL.LAYER_COUNT = 6 _C.MODEL.START_CHANNEL_COUNT = 64 _C.MODEL.MAX_CHANNEL_COUNT = 512 @@ -53,10 +71,54 @@ _C.MODEL.ENCODER = "EncoderDefault" _C.MODEL.MAPPING_TO_LATENT = "MappingToLatent" _C.MODEL.MAPPING_FROM_LATENT = "MappingFromLatent" +_C.MODEL.MAPPING_FROM_ECOG = "ECoGMappingDefault" +_C.MODEL.ONEDCONFIRST = True +_C.MODEL.RNN_TYPE = 'LSTM' +_C.MODEL.RNN_LAYERS = 4 +_C.MODEL.RNN_COMPUTE_DB_LOUDNESS = True +_C.MODEL.BIDIRECTION = True _C.MODEL.Z_REGRESSION = False +_C.MODEL.AVERAGE_W = False +_C.MODEL.TEMPORAL_W = False +_C.MODEL.GLOBAL_W = True +_C.MODEL.TEMPORAL_GLOBAL_CAT = False +_C.MODEL.RESIDUAL = False +_C.MODEL.W_CLASSIFIER = False +_C.MODEL.UNIQ_WORDS =50 +_C.MODEL.ATTENTION = [] +_C.MODEL.CYCLE = False +_C.MODEL.ATTENTIONAL_STYLE = False +_C.MODEL.HEADS = 1 +_C.MODEL.ECOG=False +_C.MODEL.SUPLOSS_ON_ECOGF=False +_C.MODEL.W_SUP=False +_C.MODEL.APPLY_PPL = False +_C.MODEL.APPLY_PPL_D = False +_C.MODEL.LESS_TEMPORAL_FEATURE = False +_C.MODEL.PPL_WEIGHT = 100 +_C.MODEL.PPL_GLOBAL_WEIGHT = 100 +_C.MODEL.PPLD_WEIGHT = 1 +_C.MODEL.PPLD_GLOBAL_WEIGHT = 1 +_C.MODEL.COMMON_Z = True +_C.MODEL.GAN = True + + +_C.MODEL.TRANSFORMER = CN() +_C.MODEL.TRANSFORMER.HIDDEN_DIM = 256 +_C.MODEL.TRANSFORMER.DIM_FEEDFORWARD = 256 +_C.MODEL.TRANSFORMER.ENCODER_ONLY = True +_C.MODEL.TRANSFORMER.ATTENTIONAL_MASK = False +_C.MODEL.TRANSFORMER.N_HEADS = 1 +_C.MODEL.TRANSFORMER.NON_LOCAL = False + +_C.FINETUNE = CN() +_C.FINETUNE.FINETUNE = False +_C.FINETUNE.ENCODER_GUIDE= False +_C.FINETUNE.FIX_GEN = False +_C.FINETUNE.SPECSUP = True _C.TRAIN = CN() - +_C.TRAIN.PROGRESSIVE = True _C.TRAIN.EPOCHS_PER_LOD = 15 _C.TRAIN.BASE_LEARNING_RATE = 0.0015 @@ -65,12 +127,14 @@ _C.TRAIN.LEARNING_DECAY_RATE = 0.1 _C.TRAIN.LEARNING_DECAY_STEPS = [] _C.TRAIN.TRAIN_EPOCHS = 110 +_C.TRAIN.W_WEIGHT = 5 +_C.TRAIN.CYCLE_WEIGHT = 5 _C.TRAIN.LOD_2_BATCH_8GPU = [512, 256, 128, 64, 32, 32] _C.TRAIN.LOD_2_BATCH_4GPU = [512, 256, 128, 64, 32, 16] _C.TRAIN.LOD_2_BATCH_2GPU = [256, 256, 128, 64, 32, 16] -_C.TRAIN.LOD_2_BATCH_1GPU = [128, 128, 128, 64, 32, 16] - +_C.TRAIN.LOD_2_BATCH_1GPU = [64, 64, 64, 64, 32, 16] +_C.TRAIN.BATCH_SIZE = 4 _C.TRAIN.SNAPSHOT_FREQ = [300, 300, 300, 100, 50, 30, 20, 20, 10] diff --git a/formant_systh.py b/formant_systh.py new file mode 100755 index 00000000..9fc21e21 --- /dev/null +++ b/formant_systh.py @@ -0,0 +1,449 @@ +import pdb +import torch +from torch import nn +# from torch.nn import functional as F +# from registry import * +import lreq as ln +import json +from tqdm import tqdm +import os +import numpy as np +from torch.nn import functional as F +from torchvision.utils import save_image +from torch.nn.parameter import Parameter +from custom_adam import LREQAdam +from ECoGDataSet import ECoGDataset +from net_formant import mel_scale, hz2ind +import matplotlib.pyplot as plt +# from matplotlib.pyplot import ion; ion() +import scipy.signal +import scipy.io.wavfile +import math +from net_formant import amplitude +import torchaudio +def spsi(msgram, fftsize, hop_length) : + """ + Takes a 2D spectrogram ([freqs,frames]), the fft legnth (= widnow length) and the hope size (both in units of samples). + Returns an audio signal. 
+ """ + msgram = np.sqrt(msgram) + numBins, numFrames = msgram.shape + y_out=np.zeros(numFrames*hop_length+fftsize-hop_length) + + m_phase=np.zeros(numBins); + m_win=scipy.signal.hanning(fftsize, sym=True) # assumption here that hann was used to create the frames of the spectrogram + + #processes one frame of audio at a time + for i in range(numFrames) : + m_mag=msgram[:, i] + for j in range(1,numBins-1) : + if(m_mag[j]>m_mag[j-1] and m_mag[j]>m_mag[j+1]) : #if j is a peak + alpha=m_mag[j-1]; + beta=m_mag[j]; + gamma=m_mag[j+1]; + denom=alpha-2*beta+gamma; + + if(denom!=0) : + p=0.5*(alpha-gamma)/denom; + else : + p=0; + + #phaseRate=2*math.pi*(j-1+p)/fftsize; #adjusted phase rate + phaseRate=2*math.pi*(j+p)/fftsize; #adjusted phase rate + m_phase[j]= m_phase[j] + hop_length*phaseRate; #phase accumulator for this peak bin + peakPhase=m_phase[j]; + + # If actual peak is to the right of the bin freq + if (p>0) : + # First bin to right has pi shift + bin=j+1; + m_phase[bin]=peakPhase+math.pi; + + # Bins to left have shift of pi + bin=j-1; + while((bin>1) and (m_mag[bin]1) and (m_mag[bin] 1: mp.spawn(_run, diff --git a/lod_driver.py b/lod_driver.py index dcd7fa08..cfc6957d 100644 --- a/lod_driver.py +++ b/lod_driver.py @@ -20,7 +20,7 @@ class LODDriver: - def __init__(self, cfg, logger, world_size, dataset_size): + def __init__(self, cfg, logger, world_size, dataset_size, progressive=True): if world_size == 8: self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_8GPU if world_size == 4: @@ -29,13 +29,13 @@ def __init__(self, cfg, logger, world_size, dataset_size): self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_2GPU if world_size == 1: self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_1GPU - + self.progressive = progressive self.world_size = world_size self.minibatch_base = 16 self.cfg = cfg self.dataset_size = dataset_size self.current_epoch = 0 - self.lod = -1 + self.lod = -1 if progressive else 5 self.in_transition = False self.logger = logger self.iteration = 0 @@ -99,23 +99,24 @@ def set_epoch(self, epoch, optimizers): self.lod = self.cfg.MODEL.LAYER_COUNT - 1 return - new_lod = min(self.cfg.MODEL.LAYER_COUNT - 1, epoch // self.cfg.TRAIN.EPOCHS_PER_LOD) - if new_lod != self.lod: - self.lod = new_lod - self.logger.info("#" * 80) - self.logger.info("# Switching LOD to %d" % self.lod) - self.logger.info("# Starting transition") - self.logger.info("#" * 80) - self.in_transition = True - for opt in optimizers: - opt.state = defaultdict(dict) + if self.progressive: + new_lod = min(self.cfg.MODEL.LAYER_COUNT - 1, epoch // self.cfg.TRAIN.EPOCHS_PER_LOD) + if new_lod != self.lod: + self.lod = new_lod + self.logger.info("#" * 80) + self.logger.info("# Switching LOD to %d" % self.lod) + self.logger.info("# Starting transition") + self.logger.info("#" * 80) + self.in_transition = True + for opt in optimizers: + opt.state = defaultdict(dict) is_in_first_half_of_cycle = (epoch % self.cfg.TRAIN.EPOCHS_PER_LOD) < (self.cfg.TRAIN.EPOCHS_PER_LOD // 2) is_growing = epoch // self.cfg.TRAIN.EPOCHS_PER_LOD == self.lod > 0 new_in_transition = is_in_first_half_of_cycle and is_growing if new_in_transition != self.in_transition: - self.in_transition = new_in_transition + self.in_transition = new_in_transition if self.progressive else False self.logger.info("#" * 80) self.logger.info("# Transition ended") self.logger.info("#" * 80) diff --git a/losses.py b/losses.py index bd49a667..2f357770 100644 --- a/losses.py +++ b/losses.py @@ -14,6 +14,7 @@ # ============================================================================== import torch 
+import math import torch.nn.functional as F @@ -28,15 +29,16 @@ def kl(mu, log_var): def reconstruction(recon_x, x, lod=None): return torch.mean((recon_x - x)**2) +def critic_loss(d_result_fake,d_result_real): + loss = (F.softplus(d_result_fake) + F.softplus(-d_result_real)).mean() + return loss -def discriminator_logistic_simple_gp(d_result_fake, d_result_real, reals, r1_gamma=10.0): - loss = (F.softplus(d_result_fake) + F.softplus(-d_result_real)) - +def discriminator_logistic_simple_gp(d_result_real, reals, r1_gamma=10.0): if r1_gamma != 0.0: real_loss = d_result_real.sum() real_grads = torch.autograd.grad(real_loss, reals, create_graph=True, retain_graph=True)[0] r1_penalty = torch.sum(real_grads.pow(2.0), dim=[1, 2, 3]) - loss = loss + r1_penalty * (r1_gamma * 0.5) + loss = r1_penalty * (r1_gamma * 0.5) return loss.mean() @@ -50,3 +52,29 @@ def discriminator_gradient_penalty(d_result_real, reals, r1_gamma=10.0): def generator_logistic_non_saturating(d_result_fake): return F.softplus(-d_result_fake).mean() + + +def pl_lengths_reg(inputs, outputs, mean_path_length, reg_on_gen, temporal_w=False,decay=0.01): + # e.g. for generator, inputs = w (B x 1 x channel x T(optianal)), outputs=images (B x 1 x T x F) + if reg_on_gen: + num_pixels = outputs[0,0,0].numel() if temporal_w else outputs[0,0].numel() # freqbands if temporal else specsize + else: + num_pixels = outputs.shape[2] # latent space size per temporal sample + pl_noise = torch.randn(outputs.shape).cuda() / math.sqrt(num_pixels) + outputs = (outputs * pl_noise).sum() + # if reg_on_gen: + # outputs = (outputs * pl_noise).sum(dim=[0,1,3]) if temporal_w else (outputs * pl_noise).sum() + # else: + # outputs = (outputs * pl_noise).sum(dim=[0,1,2]) if temporal_w else (outputs * pl_noise).sum() + + pl_grads = torch.autograd.grad(outputs=outputs, inputs=inputs, + grad_outputs=torch.ones(outputs.shape).cuda(), + create_graph=True,retain_graph=True)[0] + if reg_on_gen: + path_lengths = ((pl_grads ** 2).sum(dim=2).mean(dim=1)+1e-8).sqrt() #sum over feature, mean over repeated styles for each gen layers + else: + path_lengths = ((pl_grads ** 2).sum(dim=1)+1e-8).sqrt() + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + path_penalty = (path_lengths - path_mean).pow(2).mean() + path_lengths = path_lengths.mean() + return path_penalty,path_mean.detach(),path_lengths \ No newline at end of file diff --git a/lreq.py b/lreq.py index 8ec149a8..4b39fb88 100644 --- a/lreq.py +++ b/lreq.py @@ -48,6 +48,25 @@ def make_tuple(x, n): return x return tuple([x for _ in range(n)]) +def upscale2d(x, factor=2): + s = x.shape + x = torch.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) + x = x.repeat(1, 1, 1, factor, 1, factor) + x = torch.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) + return x + +class Blur(nn.Module): + def __init__(self, channels): + super(Blur, self).__init__() + f = np.array([1, 2, 1], dtype=np.float32) + f = f[:, np.newaxis] * f[np.newaxis, :] + f /= np.sum(f) + kernel = torch.Tensor(f).view(1, 1, 3, 3).repeat(channels, 1, 1, 1) + self.register_buffer('weight', kernel) + self.groups = channels + + def forward(self, x): + return F.conv2d(x, weight=self.weight, groups=self.groups, padding=1) class Linear(nn.Module): def __init__(self, in_features, out_features, bias=True, gain=np.sqrt(2.0), lrmul=1.0, implicit_lreq=use_implicit_lreq): @@ -87,11 +106,10 @@ def forward(self, input): else: return F.linear(input, self.weight, self.bias) - class Conv2d(nn.Module): def __init__(self, in_channels, out_channels, 
kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True, gain=np.sqrt(2.0), transpose=False, transform_kernel=False, lrmul=1.0, - implicit_lreq=use_implicit_lreq): + implicit_lreq=use_implicit_lreq,initial_weight=None): super(Conv2d, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') @@ -110,6 +128,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.transpose = transpose self.fan_in = np.prod(self.kernel_size) * in_channels // groups self.transform_kernel = transform_kernel + self.initial_weight = initial_weight if transpose: self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *self.kernel_size)) else: @@ -123,11 +142,14 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.reset_parameters() def reset_parameters(self): - self.std = self.gain / np.sqrt(self.fan_in) + self.std = self.gain / np.sqrt(self.fan_in) *self.lrmul if not self.implicit_lreq: init.normal_(self.weight, mean=0, std=1.0 / self.lrmul) else: - init.normal_(self.weight, mean=0, std=self.std / self.lrmul) + if self.initial_weight: + self.weight = self.initial_weight + else: + init.normal_(self.weight, mean=0, std=self.std / self.lrmul) setattr(self.weight, 'lr_equalization_coef', self.std) if self.bias is not None: setattr(self.bias, 'lr_equalization_coef', self.lrmul) @@ -168,6 +190,293 @@ def forward(self, x): return F.conv2d(x, w, self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) +class Conv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, + groups=1, bias=True, gain=np.sqrt(2.0), transpose=False, transform_kernel=False, lrmul=1.0, + implicit_lreq=use_implicit_lreq,initial_weight=None): + super(Conv3d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = make_tuple(kernel_size, 3) if isinstance(kernel_size,int) else kernel_size + self.stride = make_tuple(stride, 3) if isinstance(stride,int) else stride + self.padding = make_tuple(padding, 3) if isinstance(padding,int) else padding + self.output_padding = make_tuple(output_padding, 3) if isinstance(output_padding,int) else output_padding + self.dilation = make_tuple(dilation, 3) if isinstance(dilation,int) else dilation + self.groups = groups + self.gain = gain + self.lrmul = lrmul + self.transpose = transpose + self.fan_in = np.prod(self.kernel_size) * in_channels // groups + self.transform_kernel = transform_kernel + self.initial_weight = initial_weight + if transpose: + self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *self.kernel_size)) + else: + self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.std = 0 + self.implicit_lreq = implicit_lreq + self.reset_parameters() + + def reset_parameters(self): + self.std = self.gain / np.sqrt(self.fan_in) *self.lrmul + if not self.implicit_lreq: + init.normal_(self.weight, mean=0, std=1.0 / self.lrmul) + else: + if self.initial_weight: + self.weight = self.initial_weight + else: + init.normal_(self.weight, mean=0, 
std=self.std / self.lrmul) + setattr(self.weight, 'lr_equalization_coef', self.std) + if self.bias is not None: + setattr(self.bias, 'lr_equalization_coef', self.lrmul) + + if self.bias is not None: + with torch.no_grad(): + self.bias.zero_() + + def forward(self, x): + if self.transpose: + w = self.weight + if self.transform_kernel: + w = F.pad(w, (1, 1, 1, 1, 1, 1), mode='constant') + w = w[:, :, 1:, 1:, 1:] + w[:, :, :-1, 1:, 1:] + w[:, :, 1:, :-1, 1:] + w[:, :, :-1, :-1, 1:] + w[:, :, 1:, 1:, :-1] + w[:, :, :-1, 1:, :-1] + w[:, :, 1:, :-1, :-1] + w[:, :, :-1, :-1, :-1] + if not self.implicit_lreq: + bias = self.bias + if bias is not None: + bias = bias * self.lrmul + return F.conv_transpose3d(x, w * self.std, bias, stride=self.stride, + padding=self.padding, output_padding=self.output_padding, + dilation=self.dilation, groups=self.groups) + else: + return F.conv_transpose3d(x, w, self.bias, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, + groups=self.groups) + else: + w = self.weight + if self.transform_kernel: + w = F.pad(w, (1, 1, 1, 1), mode='constant') + w = (w[:, :, 1:, 1:, 1:] + w[:, :, :-1, 1:, 1:] + w[:, :, 1:, :-1, 1:] + w[:, :, :-1, :-1, 1:] + w[:, :, 1:, 1:, :-1] + w[:, :, :-1, 1:, :-1] + w[:, :, 1:, :-1, :-1] + w[:, :, :-1, :-1, :-1]) * 0.125 + if not self.implicit_lreq: + bias = self.bias + if bias is not None: + bias = bias * self.lrmul + return F.conv3d(x, w * self.std, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + else: + return F.conv3d(x, w, self.bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + +class StyleConv2dtest(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, latent_size, stride=1, padding=0, output_padding=0, dilation=1, + groups=1, bias=True, gain=np.sqrt(2.0), transpose=False, transform_kernel=False, lrmul=1.0, + implicit_lreq=False,initial_weight=None,demod=True,upsample=False,temporal_w=False): + super(StyleConv2dtest,self).__init__() + self.demod = demod + self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, output_padding=output_padding, dilation=dilation, + groups=groups, bias=False, gain=gain, transpose=transpose, + transform_kernel=transform_kernel, lrmul=lrmul, + implicit_lreq=implicit_lreq,initial_weight=initial_weight) + self.style = Linear(latent_size, 2*in_channels, gain=1) + if demod: + self.norm = nn.InstanceNorm2d(out_channels, affine=False, eps=1e-8) + self.upsample = upsample + self.transpose = transpose + if bias: + self.bias = Parameter(torch.Tensor(1,out_channels,1,1)) + with torch.no_grad(): + self.bias.zero_() + if upsample: + self.blur = Blur(out_channels) + self.noise_weight = nn.Parameter(torch.zeros(1)) + + def forward(self, x, style,noise=None): + if self.upsample and not self.transpose: + x = upscale2d(x) + w = self.style(style) + w = w.view(w.shape[0], 2, x.shape[1], 1, 1) + x = w[:,1]+x*(w[:,0]+1) + x = F.leaky_relu(self.conv(x),0.2) + if self.demod: + x = self.norm(x) + x = self.bias+x + if self.upsample: + x = self.blur(x) + if noise: + x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight, + tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]])) + + return x + + +class StyleConv2d(Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, latent_size, stride=1, padding=0, output_padding=0, dilation=1, + groups=1, bias=True, gain=np.sqrt(2.0), 
transpose=False, transform_kernel=False, lrmul=1.0, + implicit_lreq=False,initial_weight=None,demod=True,upsample=False,temporal_w=False): + super(StyleConv2d,self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, output_padding=output_padding, dilation=dilation, + groups=groups, bias=bias, gain=gain, transpose=upsample, + transform_kernel=transform_kernel, lrmul=lrmul, + implicit_lreq=implicit_lreq,initial_weight=initial_weight) + self.demod=demod + self.upsample = upsample + self.transpose = upsample + self.temporal_w = temporal_w + if upsample: + self.blur = Blur(out_channels) + if temporal_w: + self.modulation = Conv1d(latent_size, in_channels,1,1,0, gain=1) + else: + self.modulation = Linear(latent_size, in_channels, gain=1) + self.noise_weight = nn.Parameter(torch.zeros(1)) + + def forward(self, x, style,noise=None): + batch, in_channels, height, width = x.shape + if not self.temporal_w: + assert style.dim()==2, "Style dimension not mach temporal_w condition" + else: + assert style.dim()==3, "Style dimension not mach temporal_w condition" + style = self.modulation(style).view(batch, 1, in_channels, 1, 1) + w = self.weight + w = w if self.implicit_lreq else (w * self.std) + if self.transpose: + w = w.transpose(0,1) # out, in, H, W + if not self.temporal_w: + w2 = w[None, :, :, :, :] # batch, out_chan, in_chan, H, w + w = w2 * (1 + style) + if self.demod: + d = torch.rsqrt((w ** 2).sum(dim=(2, 3, 4), keepdim=True) + 1e-8) + w = w * d + _, _, _, *ws = w.shape + if self.transpose: + w = w.transpose(1,2).reshape(batch* in_channels, self.out_channels, *ws) + else: + w = w.view( batch * self.out_channels, in_channels,*ws) + if self.transform_kernel: + w = F.pad(w, (1, 1, 1, 1), mode='constant') + w = w[..., 1:, 1:] + w[..., :-1, 1:] + w[..., 1:, :-1] + w[..., :-1, :-1] + if not self.transpose: + w =w*0.25 + x = x.view(1, batch * in_channels, height, width) + + bias = self.bias + if not self.implicit_lreq: + if bias is not None: + bias = bias * self.lrmul + if self.transpose: + out = F.conv_transpose2d(x, w, None, stride=self.stride, + padding=self.padding, output_padding=self.output_padding, + dilation=self.dilation, groups=batch) + else: + out = F.conv2d(x, w, None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=batch) + + _, _, height, width = out.shape + out = out.view(batch, self.out_channels, height, width) + if bias is not None: + out = out + bias[None,:,None,None] + if self.upsample: + out = self.blur(out) + + + + else: + assert style.dim()==3, "Style dimension not mach temporal_w condition" + raise ValueError('temporal_w is not support yet') + + if noise: + out = torch.addcmul(out, value=1.0, tensor1=self.noise_weight, + tensor2=torch.randn([out.shape[0], 1, out.shape[2], out.shape[3]])) + return out + +class Conv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, + groups=1, bias=True, gain=np.sqrt(2.0), transpose=False, transform_kernel=False, lrmul=1.0, + implicit_lreq=use_implicit_lreq, bias_initial = 0.): + super(Conv1d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = make_tuple(kernel_size, 1) + self.stride = make_tuple(stride, 1) + self.padding = 
make_tuple(padding, 1) + self.output_padding = make_tuple(output_padding, 1) + self.bias_initial = bias_initial + self.dilation = make_tuple(dilation, 1) + self.groups = groups + self.gain = gain + self.lrmul = lrmul + self.transpose = transpose + self.fan_in = np.prod(self.kernel_size) * in_channels // groups + self.transform_kernel = transform_kernel + if transpose: + self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *self.kernel_size)) + else: + self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.std = 0 + self.implicit_lreq = implicit_lreq + self.reset_parameters() + + def reset_parameters(self): + self.std = self.gain / np.sqrt(self.fan_in) * self.lrmul + if not self.implicit_lreq: + init.normal_(self.weight, mean=0, std=1.0 / self.lrmul) + else: + init.normal_(self.weight, mean=0, std=self.std / self.lrmul) + setattr(self.weight, 'lr_equalization_coef', self.std) + if self.bias is not None: + setattr(self.bias, 'lr_equalization_coef', self.lrmul) + + if self.bias is not None: + with torch.no_grad(): + nn.init.constant_(self.bias,self.bias_initial) + + def forward(self, x): + if self.transpose: + w = self.weight + if self.transform_kernel: + w = F.pad(w, (1, 1), mode='constant') + w = w[:, :, 1:] + w[:, :, :-1] + if not self.implicit_lreq: + bias = self.bias + if bias is not None: + bias = bias * self.lrmul + return F.conv_transpose1d(x, w * self.std, bias, stride=self.stride, + padding=self.padding, output_padding=self.output_padding, + dilation=self.dilation, groups=self.groups) + else: + return F.conv_transpose1d(x, w, self.bias, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, + groups=self.groups) + else: + w = self.weight + if self.transform_kernel: + w = F.pad(w, (1, 1), mode='constant') + w = (w[:, :, 1:] + w[:, :, :-1]) * 0.5 + if not self.implicit_lreq: + bias = self.bias + if bias is not None: + bias = bias * self.lrmul + return F.conv1d(x, w * self.std, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + else: + return F.conv1d(x, w, self.bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) class ConvTranspose2d(Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, @@ -187,6 +496,23 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, lrmul=lrmul, implicit_lreq=implicit_lreq) +class ConvTranspose1d(Conv1d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, + groups=1, bias=True, gain=np.sqrt(2.0), transform_kernel=False, lrmul=1.0, implicit_lreq=use_implicit_lreq): + super(ConvTranspose1d, self).__init__(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + bias=bias, + gain=gain, + transpose=True, + transform_kernel=transform_kernel, + lrmul=lrmul, + implicit_lreq=implicit_lreq) class SeparableConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, @@ -199,6 +525,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, def forward(self, x): return 
self.channel_conv(self.spatial_conv(x)) +class SeparableConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, + bias=True, gain=np.sqrt(2.0), transpose=False): + super(SeparableConv1d, self).__init__() + self.spatial_conv = Conv1d(in_channels, in_channels, kernel_size, stride, padding, output_padding, dilation, + in_channels, False, 1, transpose) + self.channel_conv = Conv1d(in_channels, out_channels, 1, bias, 1, gain=gain) + + def forward(self, x): + return self.channel_conv(self.spatial_conv(x)) class SeparableConvTranspose2d(Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, @@ -206,3 +542,10 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, super(SeparableConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, output_padding, dilation, bias, gain, True) +class SeparableConvTranspose1d(Conv1d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, + bias=True, gain=np.sqrt(2.0)): + super(SeparableConvTranspose1d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, + output_padding, dilation, bias, gain, True) + + diff --git a/make_figures/make_recon_figure_interpolation.py b/make_figures/make_recon_figure_interpolation.py index f42b4c40..680bb1aa 100644 --- a/make_figures/make_recon_figure_interpolation.py +++ b/make_figures/make_recon_figure_interpolation.py @@ -22,6 +22,7 @@ from dlutils.pytorch import count_parameters from defaults import get_cfg_defaults import lreq +import os from PIL import Image @@ -58,12 +59,38 @@ def sample(cfg, logger): layer_count=cfg.MODEL.LAYER_COUNT, maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE, - truncation_psi=cfg.MODEL.TRUNCATIOM_PSI, - truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF, + dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, + style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=cfg.MODEL.CHANNELS, generator=cfg.MODEL.GENERATOR, - encoder=cfg.MODEL.ENCODER) + encoder=cfg.MODEL.ENCODER, + ecog_encoder=cfg.MODEL.MAPPING_FROM_ECOG, + z_regression=cfg.MODEL.Z_REGRESSION, + average_w = cfg.MODEL.AVERAGE_W, + temporal_w = cfg.MODEL.TEMPORAL_W, + global_w = cfg.MODEL.GLOBAL_W, + temporal_global_cat = cfg.MODEL.TEMPORAL_GLOBAL_CAT, + spec_chans = cfg.DATASET.SPEC_CHANS, + temporal_samples = cfg.DATASET.TEMPORAL_SAMPLES, + init_zeros = cfg.MODEL.TEMPORAL_W, + residual = cfg.MODEL.RESIDUAL, + w_classifier = cfg.MODEL.W_CLASSIFIER, + uniq_words = cfg.MODEL.UNIQ_WORDS, + attention = cfg.MODEL.ATTENTION, + cycle = cfg.MODEL.CYCLE, + w_weight = cfg.TRAIN.W_WEIGHT, + cycle_weight=cfg.TRAIN.CYCLE_WEIGHT, + attentional_style=cfg.MODEL.ATTENTIONAL_STYLE, + heads = cfg.MODEL.HEADS, + suploss_on_ecog = cfg.MODEL.SUPLOSS_ON_ECOGF, + less_temporal_feature = cfg.MODEL.LESS_TEMPORAL_FEATURE, + ppl_weight=cfg.MODEL.PPL_WEIGHT, + ppl_global_weight=cfg.MODEL.PPL_GLOBAL_WEIGHT, + ppld_weight=cfg.MODEL.PPLD_WEIGHT, + ppld_global_weight=cfg.MODEL.PPLD_GLOBAL_WEIGHT, + common_z = cfg.MODEL.COMMON_Z, + ) model.cuda(0) model.eval() model.requires_grad_(False) @@ -73,7 +100,6 @@ def sample(cfg, logger): mapping_tl = model.mapping_tl mapping_fl = model.mapping_fl dlatent_avg = model.dlatent_avg - logger.info("Trainable parameters generator:") count_parameters(decoder) @@ -97,54 +123,83 @@ def sample(cfg, logger): logger=logger, 
save=False) - extra_checkpoint_data = checkpointer.load() + extra_checkpoint_data = checkpointer.load(file_name='./training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_noprogressive_HBw_ppl_ppld_localreg_debug/model_tmp_lod6.pth') + # extra_checkpoint_data = checkpointer.load(file_name='./training_artifacts/ecog_residual_cycle_attention3264wIN_specchan64_more_attentfeatures/model_tmp_lod4.pth') model.eval() layer_count = cfg.MODEL.LAYER_COUNT - def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + if cfg.MODEL.TEMPORAL_W and cfg.MODEL.GLOBAL_W: + Z = (Z[0].repeat(1, model.mapping_fl.num_layers, 1, 1),Z[1].repeat(1, model.mapping_fl.num_layers, 1)) + else: + if cfg.MODEL.TEMPORAL_W: + Z = Z.repeat(1, model.mapping_fl.num_layers, 1, 1) + else: + Z = Z.repeat(1, model.mapping_fl.num_layers, 1) return Z def decode(x): - layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis] - ones = torch.ones(layer_idx.shape, dtype=torch.float32) - coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones) + # layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis] + # ones = torch.ones(layer_idx.shape, dtype=torch.float32) + # coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones) # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs) return model.decoder(x, layer_count - 1, 1, noise=True) - rnd = np.random.RandomState(4) latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE) path = cfg.DATASET.SAMPLES_PATH im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1) - pathA = '00001.png' - pathB = '00022.png' - pathC = '00077.png' - pathD = '00016.png' - + # pathA = 'kite.npy' + # pathB = 'cat.npy' + # pathC = 'hat.npy' + # pathD = 'cake.npy' + pathA = 'vase.npy' + pathB = 'cow.npy' + pathC = 'hat.npy' + pathD = 'cake.npy' + + + # def open_image(filename): + # img = np.asarray(Image.open(path + '/' + filename)) + # if img.shape[2] == 4: + # img = img[:, :, :3] + # im = img.transpose((2, 0, 1)) + # x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1. + # if x.shape[0] == 4: + # x = x[:3] + # factor = x.shape[2] // im_size + # if factor != 1: + # x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0] + # assert x.shape[2] == im_size + # _latents = encode(x[None, ...].cuda()) + # latents = _latents[0, 0] + # return latents def open_image(filename): - img = np.asarray(Image.open(path + '/' + filename)) - if img.shape[2] == 4: - img = img[:, :, :3] - im = img.transpose((2, 0, 1)) - x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1. 
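[Editor's note] A small shape sketch of the encode() path in the interpolation script above: the encoder yields one latent code per sample, which is then tiled across all 2 * MODEL.LAYER_COUNT style inputs of the decoder. The numbers below assume the config defaults (LAYER_COUNT 6, latent size 128) and are illustrative only.

import torch

num_layers = 2 * 6                    # 2 * MODEL.LAYER_COUNT, assuming the default of 6
w = torch.randn(1, 1, 128)            # a single (batch, 1, latent) code from the encoder
w_per_layer = w.repeat(1, num_layers, 1)
print(w_per_layer.shape)              # torch.Size([1, 12, 128])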
- if x.shape[0] == 4: - x = x[:3] - factor = x.shape[2] // im_size + im = np.load(os.path.join(path, filename)) + x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() + factor = x.shape[1] // im_size if factor != 1: x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0] - assert x.shape[2] == im_size + assert x.shape[1] == im_size _latents = encode(x[None, ...].cuda()) - latents = _latents[0, 0] + if cfg.MODEL.TEMPORAL_W and cfg.MODEL.GLOBAL_W: + latents = (_latents[0][0,0],_latents[1][0,0]) + else: + latents = _latents[0, 0] return latents def make(w): with torch.no_grad(): - w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1) + if cfg.MODEL.TEMPORAL_W and cfg.MODEL.GLOBAL_W: + w = (w[0][None, None, ...].repeat(1, model.mapping_fl.num_layers, 1, 1),w[1][None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)) + else: + if cfg.MODEL.TEMPORAL_W: + w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1, 1) + else: + w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1) x_rec = decode(w) return x_rec @@ -152,9 +207,9 @@ def make(w): wb = open_image(pathB) wc = open_image(pathC) wd = open_image(pathD) - - height = 7 - width = 7 + import pdb;pdb.set_trace() + height = 10 + width = 10 images = [] @@ -168,18 +223,22 @@ def make(w): kc = (1.0 - kh) * kv kd = kh * kv - w = ka * wa + kb * wb + kc * wc + kd * wd + if cfg.MODEL.TEMPORAL_W and cfg.MODEL.GLOBAL_W: + w = ((1-kh) * wa[0] + kh * wb[0] , (1-kv) * wa[1] + kv * wb[1]) + else: + w = ka * wa + kb * wb + kc * wc + kd * wd interpolated = make(w) images.append(interpolated) images = torch.cat(images) - - save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations.png' % cfg.NAME, nrow=width) - save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations.jpg' % cfg.NAME, nrow=width) + images = images.permute(0,1,3,2) + os.makedirs('make_figures/output/%s' % cfg.NAME, exist_ok=True) + save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations_vase_cow.png' % cfg.NAME, nrow=width) + save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations_vase_cow.jpg' % cfg.NAME, nrow=width) if __name__ == "__main__": gpu_count = 1 - run(sample, get_cfg_defaults(), description='ALAE-interpolations', default_config='configs/ffhq.yaml', + run(sample, get_cfg_defaults(), description='ALAE-interpolations', default_config='configs/ecog_style2.yaml', world_size=gpu_count, write_log=False) diff --git a/model.py b/model.py index 78c1da9a..e32b2f65 100644 --- a/model.py +++ b/model.py @@ -20,146 +20,434 @@ class DLatent(nn.Module): - def __init__(self, dlatent_size, layer_count): + def __init__(self, dlatent_size, layer_count,temporal_w=False,temporal_samples=128): super(DLatent, self).__init__() - buffer = torch.zeros(layer_count, dlatent_size, dtype=torch.float32) + if temporal_w: + buffer = torch.zeros(layer_count, dlatent_size, temporal_samples, dtype=torch.float32) + else: + buffer = torch.zeros(layer_count, dlatent_size, dtype=torch.float32) + self.register_buffer('buff', buffer) + +class PPL_MEAN(nn.Module): + def __init__(self): + super(PPL_MEAN, self).__init__() + buffer = torch.zeros(1, dtype=torch.float32) self.register_buffer('buff', buffer) class Model(nn.Module): - def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_layers=5, dlatent_avg_beta=None, - truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3, generator="", - encoder="", z_regression=False): + def __init__(self, 
startf=32, maxf=256, layer_count=3, latent_size=128, uniq_words=50, mapping_layers=5, dlatent_avg_beta=None, + truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3, generator="", encoder="", + z_regression=False,average_w=False,spec_chans = 128,temporal_samples=128,temporal_w=False, global_w=True,temporal_global_cat = False,init_zeros=False, + residual=False,w_classifier=False,attention=None,cycle=None,w_weight=1.0,cycle_weight=1.0, attentional_style=False,heads=1, + ppl_weight=100,ppl_global_weight=100,ppld_weight=1,ppld_global_weight=1,common_z = False, + with_ecog = False, ecog_encoder="",suploss_on_ecog=False,less_temporal_feature=False): super(Model, self).__init__() self.layer_count = layer_count self.z_regression = z_regression - + self.common_z = common_z + self.temporal_w = temporal_w + self.global_w = global_w + self.temporal_global_cat = temporal_global_cat + self.w_classifier = w_classifier + self.cycle = cycle + self.w_weight=w_weight + self.cycle_weight=cycle_weight + self.ppl_weight = ppl_weight + self.ppl_global_weight = ppl_global_weight + self.ppld_weight = ppld_weight + self.ppld_global_weight = ppld_global_weight + self.suploss_on_ecog = suploss_on_ecog + self.with_ecog = with_ecog + latent_feature = latent_size//4 if (temporal_w and less_temporal_feature) else latent_size self.mapping_tl = MAPPINGS["MappingToLatent"]( - latent_size=latent_size, + latent_size=latent_feature, dlatent_size=latent_size, mapping_fmaps=latent_size, - mapping_layers=3) + mapping_layers=5 if temporal_w else 3, + temporal_w = temporal_w, + global_w = global_w) + + self.mapping_tw = MAPPINGS["MappingToWord"]( + latent_size=latent_feature, + uniq_words=uniq_words, + mapping_fmaps=latent_size, + mapping_layers=1, + temporal_w = temporal_w) self.mapping_fl = MAPPINGS["MappingFromLatent"]( num_layers=2 * layer_count, - latent_size=latent_size, + latent_size=latent_feature, dlatent_size=latent_size, mapping_fmaps=latent_size, - mapping_layers=mapping_layers) + mapping_layers=mapping_layers, + temporal_w = temporal_w, + global_w = global_w) self.decoder = GENERATORS[generator]( startf=startf, layer_count=layer_count, maxf=maxf, - latent_size=latent_size, - channels=channels) + latent_size=latent_feature, + channels=channels, + spec_chans=spec_chans, temporal_samples = temporal_samples, + temporal_w = temporal_w, + global_w = global_w, + temporal_global_cat = temporal_global_cat, + init_zeros = init_zeros, + residual = residual, + attention=attention, + attentional_style=attentional_style, + heads = heads, + ) self.encoder = ENCODERS[encoder]( startf=startf, layer_count=layer_count, maxf=maxf, - latent_size=latent_size, - channels=channels) - - self.dlatent_avg = DLatent(latent_size, self.mapping_fl.num_layers) + latent_size=latent_feature, + channels=channels, + spec_chans=spec_chans, temporal_samples = temporal_samples, + average_w=average_w, + temporal_w = temporal_w, + global_w = global_w, + temporal_global_cat = temporal_global_cat, + residual = residual, + attention=attention, + attentional_style=attentional_style, + heads = heads, + ) + + if with_ecog: + self.ecog_encoder = ECOG_ENCODER[ecog_encoder]( + latent_size = latent_feature, + average_w = average_w, + temporal_w=temporal_w, + global_w = global_w, + attention=attention, + temporal_samples=temporal_samples, + attentional_style=attentional_style, + heads=heads, + ) + + self.dlatent_avg = DLatent(latent_feature, self.mapping_fl.num_layers,temporal_w=temporal_w) + self.ppl_mean = PPL_MEAN() + self.ppl_d_mean = 
PPL_MEAN() + if temporal_w and global_w: + self.dlatent_avg_global = DLatent(latent_feature, self.mapping_fl.num_layers,temporal_w=False) + self.ppl_mean_global = PPL_MEAN() + self.ppl_d_mean_global = PPL_MEAN() self.latent_size = latent_size self.dlatent_avg_beta = dlatent_avg_beta self.truncation_psi = truncation_psi self.style_mixing_prob = style_mixing_prob self.truncation_cutoff = truncation_cutoff - def generate(self, lod, blend_factor, z=None, count=32, mixing=True, noise=True, return_styles=False, no_truncation=False): + def generate(self, lod, blend_factor, z=None, z_global=None, count=32, mixing=True, noise=True, return_styles=False, no_truncation=False,ecog_only=True,ecog=None,mask_prior=None): if z is None: z = torch.randn(count, self.latent_size) - styles = self.mapping_fl(z)[:, 0] - s = styles.view(styles.shape[0], 1, styles.shape[1]) - - styles = s.repeat(1, self.mapping_fl.num_layers, 1) - - if self.dlatent_avg_beta is not None: - with torch.no_grad(): - batch_avg = styles.mean(dim=0) - self.dlatent_avg.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta) - - if mixing and self.style_mixing_prob is not None: - if random.random() < self.style_mixing_prob: - z2 = torch.randn(count, self.latent_size) - styles2 = self.mapping_fl(z2)[:, 0] - styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1]).repeat(1, self.mapping_fl.num_layers, 1) - - layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] - cur_layers = (lod + 1) * 2 - mixing_cutoff = random.randint(1, cur_layers) - styles = torch.where(layer_idx < mixing_cutoff, styles, styles2) - - if (self.truncation_psi is not None) and not no_truncation: - layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] - ones = torch.ones(layer_idx.shape, dtype=torch.float32) - coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) - styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) + if z_global is None: + z_global = z if self.common_z else torch.randn(count, self.latent_size) + if ecog is not None: + styles_ecog = self.ecog_encoder(ecog,mask_prior) + if self.temporal_w and self.global_w: + styles_ecog, styles_ecog_global = styles_ecog + s_ecog = styles_ecog.view(styles_ecog.shape[0], 1, styles_ecog.shape[1],styles_ecog.shape[2]) + styles_ecog = s_ecog.repeat(1, self.mapping_fl.num_layers, 1,1) + s_ecog_global = styles_ecog_global.view(styles_ecog_global.shape[0], 1, styles_ecog_global.shape[1]) + styles_ecog_global = s_ecog_global.repeat(1, self.mapping_fl.num_layers, 1) + else: + if self.temporal_w: + s_ecog = styles_ecog.view(styles_ecog.shape[0], 1, styles_ecog.shape[1],styles_ecog.shape[2]) + styles_ecog = s_ecog.repeat(1, self.mapping_fl.num_layers, 1,1) + if self.global_w: + s_ecog = styles_ecog.view(styles_ecog.shape[0], 1, styles_ecog.shape[1]) + styles_ecog = s_ecog.repeat(1, self.mapping_fl.num_layers, 1) + if self.w_classifier: + Z__ = self.mapping_tw(styles_ecog, styles_ecog_global) + + if (ecog is None) or (not ecog_only): + if (self.temporal_w and self.global_w): + styles = self.mapping_fl(z,z_global) + styles, styles_global = styles + styles = styles[:,0] + styles_global = styles_global[:,0] + else: + styles = self.mapping_fl(z)[:, 0] + if self.temporal_w and self.global_w: + s = styles.view(styles.shape[0], 1, styles.shape[1],styles.shape[2]) + styles = s.repeat(1, self.mapping_fl.num_layers, 1,1) + s_global = styles_global.view(styles_global.shape[0], 1, styles_global.shape[1]) + styles_global = s_global.repeat(1, 
self.mapping_fl.num_layers, 1) + else: + if self.temporal_w: + s = styles.view(styles.shape[0], 1, styles.shape[1],styles.shape[2]) + styles = s.repeat(1, self.mapping_fl.num_layers, 1,1) + if self.global_w: + s = styles.view(styles.shape[0], 1, styles.shape[1]) + styles = s.repeat(1, self.mapping_fl.num_layers, 1) + + if self.dlatent_avg_beta is not None: + with torch.no_grad(): + batch_avg = styles.mean(dim=0) + self.dlatent_avg.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta) + if self.temporal_w and self.global_w: + batch_avg = styles_global.mean(dim=0) + self.dlatent_avg_global.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta) + + if mixing and self.style_mixing_prob is not None: + if random.random() < self.style_mixing_prob: + cur_layers = (lod + 1) * 2 + mixing_cutoff = random.randint(1, cur_layers) + layer_idx = torch.arange(self.mapping_fl.num_layers) + z2 = torch.randn(count, self.latent_size) + z2_global = z2 if self.common_z else torch.randn(count, self.latent_size) + if (self.temporal_w and self.global_w): + styles2 = self.mapping_fl(z2,z2_global) + styles2, styles2_global = styles2 + styles2 = styles2[:,0] + styles2_global = styles2_global[:,0] + else: + styles2 = self.mapping_fl(z2)[:, 0] + if self.temporal_w and self.global_w: + styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1],styles2.shape[2]).repeat(1, self.mapping_fl.num_layers, 1,1) + styles = torch.where(layer_idx[np.newaxis, :, np.newaxis,np.newaxis] < mixing_cutoff, styles, styles2) + styles2_global = styles2_global.view(styles2_global.shape[0], 1, styles2_global.shape[1]).repeat(1, self.mapping_fl.num_layers, 1) + styles_global = torch.where(layer_idx[np.newaxis, :, np.newaxis] < mixing_cutoff, styles_global, styles2_global) + else: + if self.temporal_w: + styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1],styles2.shape[2]).repeat(1, self.mapping_fl.num_layers, 1,1) + styles = torch.where(layer_idx[np.newaxis, :, np.newaxis,np.newaxis] < mixing_cutoff, styles, styles2) + if self.global_w: + styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1]).repeat(1, self.mapping_fl.num_layers, 1) + styles = torch.where(layer_idx[np.newaxis, :, np.newaxis] < mixing_cutoff, styles, styles2) + + if (self.truncation_psi is not None) and not no_truncation: + if self.temporal_w and self.global_w: + layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis,np.newaxis] + ones = torch.ones(layer_idx.shape, dtype=torch.float32) + coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) + styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) + layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] + ones = torch.ones(layer_idx.shape, dtype=torch.float32) + coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) + styles_global = torch.lerp(self.dlatent_avg_global.buff.data, styles_global, coefs) + else: + if self.temporal_w: + layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis,np.newaxis] + ones = torch.ones(layer_idx.shape, dtype=torch.float32) + coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) + styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) + if self.global_w: + layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] + ones = torch.ones(layer_idx.shape, dtype=torch.float32) + coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, 
ones) + styles = torch.lerp(self.dlatent_avg_global.buff.data, styles, coefs) + + # import pdb; pdb.set_trace() + if ecog is not None: + if (not ecog_only): + styles = torch.cat([styles_ecog,styles],dim=0) + s = torch.cat([s_ecog,s],dim=0) + if self.temporal_w and self.global_w: + styles_global = torch.cat([styles_ecog_global,styles_global],dim=0) + s_global = torch.cat([s_ecog_global,s_global],dim=0) + else: + styles = styles_ecog + s = s_ecog + if self.temporal_w and self.global_w: + styles_global = styles_ecog_global + s_global = s_ecog_global + + if self.temporal_w and self.global_w: + styles = (styles,styles_global) rec = self.decoder.forward(styles, lod, blend_factor, noise) - if return_styles: - return s, rec + # import pdb; pdb.set_trace() + if self.w_classifier: + if return_styles: + if self.temporal_w and self.global_w: + return (s, s_global), rec, Z__ + else: + return s, rec, Z__ + else: + return rec,Z__ else: - return rec + if return_styles: + if self.temporal_w and self.global_w: + return (s, s_global), rec + else: + return s, rec + else: + return rec - def encode(self, x, lod, blend_factor): + def encode(self, x, lod, blend_factor,word_classify=False): Z = self.encoder(x, lod, blend_factor) - Z_ = self.mapping_tl(Z) - return Z[:, :1], Z_[:, 1, 0] + if self.temporal_w and self.global_w: + Z,Z_global = Z + + Z_ = self.mapping_tl(Z[:,0],Z_global[:,0]) if (self.temporal_w and self.global_w) else self.mapping_tl(Z[:,0]) + if word_classify: + Z__ = self.mapping_tw(Z[:,0],Z_global[:,0]) if (self.temporal_w and self.global_w) else self.mapping_tw(Z[:,0]) + if self.temporal_w and self.global_w: + return (Z[:, :1],Z_global[:,:1]), Z_[:, 1, 0], Z__ + else: + return Z[:, :1], Z_[:, 1, 0], Z__ + else: + if self.temporal_w and self.global_w: + return (Z[:, :1],Z_global[:,:1]), Z_[:, 1, 0] + else: + return Z[:, :1], Z_[:, 1, 0] - def forward(self, x, lod, blend_factor, d_train, ae): + def forward(self, x, lod, blend_factor, d_train, ae, tracker,words=None,apply_encoder_guide=False,apply_w_classifier=False,apply_cycle=True,apply_gp=True,apply_ppl=True,apply_ppl_d=False,ecog=None,sup=True,mask_prior=None,gan=True): if ae: self.encoder.requires_grad_(True) z = torch.randn(x.shape[0], self.latent_size) - s, rec = self.generate(lod, blend_factor, z=z, mixing=False, noise=True, return_styles=True) - - Z, d_result_real = self.encode(rec, lod, blend_factor) - - assert Z.shape == s.shape + if self.temporal_w and self.global_w: + z_global = z if self.common_z else torch.randn(x.shape[0], self.latent_size) + else: + z_global = None + s, rec = self.generate(lod, blend_factor, z=z, z_global=z_global, mixing=False, noise=True, return_styles=True,ecog=ecog,mask_prior=mask_prior) + + Z, _ = self.encode(rec, lod, blend_factor) + do_cycle = self.cycle and apply_cycle + if do_cycle: + Z_real, _ = self.encode(x, lod, blend_factor) + if self.temporal_w and self.global_w: + Z_real,Z_real_global = Z_real + Z_real_global = Z_real_global.repeat(1, self.mapping_fl.num_layers, 1) + Z_real = Z_real.repeat(1, self.mapping_fl.num_layers, 1) + rec = self.decoder.forward((Z_real,Z_real_global) if (self.temporal_w and self.global_w) else Z_real, lod, blend_factor, noise=True) + Lcycle = self.cycle_weight*torch.mean((rec - x).abs()) + tracker.update(dict(Lcycle=Lcycle)) + else: + Lcycle=0 + + # assert Z.shape == s.shape if self.z_regression: - Lae = torch.mean(((Z[:, 0] - z)**2)) + Lae = self.w_weight*torch.mean(((Z[:, 0] - z)**2)) else: - Lae = torch.mean(((Z - s.detach())**2)) - - return Lae + if self.temporal_w and 
self.global_w: + Z,Z_global = Z + s,s_global = s + Lae = self.w_weight*(torch.mean(((Z - s.detach())**2)) + torch.mean(((Z_global - s_global.detach())**2))) + else: + Lae = self.w_weight*torch.mean(((Z - s.detach())**2)) + tracker.update(dict(Lae=Lae)) + + return Lae+Lcycle elif d_train: with torch.no_grad(): - Xp = self.generate(lod, blend_factor, count=x.shape[0], noise=True) + Xp = self.generate(lod, blend_factor, count=x.shape[0], noise=True,ecog=ecog,mask_prior=mask_prior) self.encoder.requires_grad_(True) + + if apply_w_classifier: + _, d_result_real, word_logits = self.encode(x, lod, blend_factor,word_classify=True) + else: + xs = torch.cat([x,Xp.requires_grad_(True)],dim=0) + w, d_result = self.encode(xs, lod, blend_factor) + if self.temporal_w and self.global_w: + w, w_global = w + w_real_global = w_global[:w_global.shape[0]//2] + w_fake_global = w_global[w_global.shape[0]//2:] + w_real = w[:w.shape[0]//2] + w_fake = w[w.shape[0]//2:] + d_result_real = d_result[:d_result.shape[0]//2] + d_result_fake = d_result[d_result.shape[0]//2:] + # w_real, d_result_real = self.encode(x, lod, blend_factor) + # w_fake, d_result_fake = self.encode(Xp.requires_grad_(True), lod, blend_factor) + + loss_d = losses.critic_loss(d_result_fake, d_result_real) + tracker.update(dict(loss_d=loss_d)) + if apply_gp: + loss_gp = losses.discriminator_logistic_simple_gp(d_result_real, x) + loss_d += loss_gp + else: + loss_gp=0 + if apply_ppl_d: + path_loss_d, self.ppl_d_mean.buff.data, path_lengths_d = losses.pl_lengths_reg(xs, w, self.ppl_d_mean.buff.data,reg_on_gen=False,temporal_w = self.temporal_w) + path_loss_d =path_loss_d*self.ppld_weight + tracker.update(dict(path_loss_d=path_loss_d,path_lengths_d=path_lengths_d)) + if self.temporal_w and self.global_w and self.ppld_global_weight != 0: + path_loss_d_global, self.ppl_d_mean_global.buff.data, path_lengths_d_global = losses.pl_lengths_reg(xs, w_global, self.ppl_d_mean_global.buff.data,reg_on_gen=False,temporal_w = False) + path_loss_d_global = path_loss_d_global*self.ppld_global_weight + tracker.update(dict(path_loss_d_global=path_loss_d_global,path_lengths_d_global=path_lengths_d_global)) + path_loss_d = path_loss_d+path_loss_d_global + # path_loss_d =path_loss_d*self.ppl_weight + # path_loss, self.ppl_mean.buff.data, path_lengths = losses.pl_lengths_reg(torch.cat([x,Xp],dim=0), torch.cat([w_real,w_fake],dim=0), self.ppl_mean.buff.data ) + else: + path_loss_d=0 + if apply_w_classifier: + loss_word = F.cross_entropy(word_logits,words) + tracker.update(dict(loss_word=loss_word)) + else: + loss_word=0 + return loss_d+loss_word+path_loss_d - _, d_result_real = self.encode(x, lod, blend_factor) - - _, d_result_fake = self.encode(Xp.detach(), lod, blend_factor) - - loss_d = losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x) - return loss_d else: with torch.no_grad(): z = torch.randn(x.shape[0], self.latent_size) + if self.temporal_w and self.global_w: + z_global = z if self.common_z else torch.randn(x.shape[0], self.latent_size) + else: + z_global = None self.encoder.requires_grad_(False) + s, rec = self.generate(lod, blend_factor, count=x.shape[0], z=z.detach(), z_global=z_global, noise=True,return_styles=True,ecog=ecog,mask_prior=mask_prior) + if self.temporal_w and self.global_w: + s,s_global = s - rec = self.generate(lod, blend_factor, count=x.shape[0], z=z.detach(), noise=True) - - _, d_result_fake = self.encode(rec, lod, blend_factor) + if gan: + _, d_result_fake = self.encode(rec, lod, blend_factor) - loss_g = 
losses.generator_logistic_non_saturating(d_result_fake) - - return loss_g + loss_g = losses.generator_logistic_non_saturating(d_result_fake) + tracker.update(dict(loss_g=loss_g)) + else: + loss_g = 0 + + if apply_encoder_guide: + Z_real, _ = self.encode(x, lod, blend_factor) + if self.temporal_w and self.global_w: + Z_real,Z_real_global = Z_real + loss_w_sup = self.w_weight*(torch.mean(((Z_real - s)**2))+torch.mean(((Z_real_global - s_global)**2))) + else: + loss_w_sup = self.w_weight*torch.mean(((Z_real - s)**2)) + tracker.update(dict(loss_w_sup=loss_w_sup)) + else: + loss_w_sup=0 + + if apply_ppl: + path_loss, self.ppl_mean.buff.data, path_lengths = losses.pl_lengths_reg(s, rec, self.ppl_mean.buff.data,reg_on_gen=True,temporal_w = self.temporal_w) + path_loss =path_loss*self.ppl_weight + tracker.update(dict(path_loss=path_loss, path_lengths=path_lengths)) + if self.temporal_w and self.global_w: + path_loss_global, self.ppl_mean_global.buff.data, path_lengths_global = losses.pl_lengths_reg(s_global, rec, self.ppl_mean_global.buff.data,reg_on_gen=True,temporal_w = False) + path_loss_global =path_loss_global*self.ppl_global_weight + tracker.update(dict(path_loss_global=path_loss_global, path_lengths_global=path_lengths_global)) + path_loss = path_loss+path_loss_global + else: + path_loss = 0 + if ecog is not None and sup: + loss_sup = torch.mean((rec - x).abs()) + tracker.update(dict(loss_sup=loss_sup)) + else: + loss_sup = 0 + if ecog is not None and self.suploss_on_ecog: + return loss_g+ path_loss, loss_sup + loss_w_sup + else: + return loss_g+ path_loss+ loss_sup + loss_w_sup - def lerp(self, other, betta): + def lerp(self, other, betta,w_classifier=False): if hasattr(other, 'module'): other = other.module with torch.no_grad(): - params = list(self.mapping_tl.parameters()) + list(self.mapping_fl.parameters()) + list(self.decoder.parameters()) + list(self.encoder.parameters()) + list(self.dlatent_avg.parameters()) - other_param = list(other.mapping_tl.parameters()) + list(other.mapping_fl.parameters()) + list(other.decoder.parameters()) + list(other.encoder.parameters()) + list(other.dlatent_avg.parameters()) + params = list(self.mapping_tl.parameters())+ list(self.mapping_fl.parameters()) + list(self.decoder.parameters()) + list(self.encoder.parameters()) + list(self.dlatent_avg.parameters()) + (list(other.dlatent_avg_global.parameters()) if (self.temporal_w and self.global_w) else []) + (list(self.mapping_tw.parameters()) if w_classifier else []) + (list(self.ecog_encoder.parameters()) if self.with_ecog else []) + other_param = list(other.mapping_tl.parameters()) + list(other.mapping_fl.parameters()) + list(other.decoder.parameters()) + list(other.encoder.parameters()) + list(other.dlatent_avg.parameters()) + (list(other.dlatent_avg_global.parameters()) if (self.temporal_w and self.global_w) else []) + (list(other.mapping_tw.parameters()) if w_classifier else []) + (list(other.ecog_encoder.parameters()) if self.with_ecog else []) for p, p_other in zip(params, other_param): p.data.lerp_(p_other.data, 1.0 - betta) @@ -185,7 +473,7 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_ latent_size=latent_size, channels=channels) - self.dlatent_avg = DLatent(latent_size, self.mapping_fl.num_layers) + self.dlatent_avg = DLatent(latent_size, self.mapping_fl.num_layers,temporal_w=temporal_w) self.latent_size = latent_size self.dlatent_avg_beta = dlatent_avg_beta self.truncation_psi = truncation_psi diff --git a/model_formant.py b/model_formant.py new file mode 
100644 index 00000000..188a1a20 --- /dev/null +++ b/model_formant.py @@ -0,0 +1,1061 @@ +import pdb +import random +from tracker import LossTracker +import losses +from net_formant import * +import numpy as np +from torch.nn import functional as F +def compdiff(comp): + return ((comp[:,:,1:]-comp[:,:,:-1]).abs()).mean() + +def compdiffd2(comp): + diff = comp[:,:,1:]-comp[:,:,:-1] + return ((diff[:,:,1:]-diff[:,:,:-1]).abs()).mean() + +def _expand_binary_labels(labels, label_weights, label_channels): + bin_labels = labels.new_full((labels.size(0), label_channels), 0) + inds = torch.nonzero(labels >= 1).squeeze() + if inds.numel() > 0: + bin_labels[inds, labels[inds] - 1] = 1 + bin_label_weights = label_weights.view(-1, 1).expand( + label_weights.size(0), label_channels) + return bin_labels, bin_label_weights + +class GHMC(nn.Module): + def __init__( + self, + bins=30, + momentum=0, + use_sigmoid=True, + loss_weight=1.0): + super(GHMC, self).__init__() + self.bins = bins + self.momentum = momentum + self.edges = [float(x) / bins for x in range(bins+1)] + self.edges[-1] += 1e-6 + if momentum > 0: + self.acc_sum = [0.0 for _ in range(bins)] + self.use_sigmoid = use_sigmoid + self.loss_weight = loss_weight + + def forward(self, pred, target, label_weight, *args, **kwargs): + """ Args: + pred [batch_num, class_num]: + The direct prediction of classification fc layer. + target [batch_num, class_num]: + Binary class target for each sample. + label_weight [batch_num, class_num]: + the value is 1 if the sample is valid and 0 if ignored. + """ + if not self.use_sigmoid: + raise NotImplementedError + # the target should be binary class label + if pred.dim() != target.dim(): + target, label_weight = _expand_binary_labels(target, label_weight, pred.size(-1)) + target, label_weight = target.float(), label_weight.float() + edges = self.edges + mmt = self.momentum + weights = torch.zeros_like(pred) + + # gradient length + g = torch.abs(pred.sigmoid().detach() - target) + + valid = label_weight > 0 + tot = max(valid.float().sum().item(), 1.0) + n = 0 # n valid bins + for i in range(self.bins): + inds = (g >= edges[i]) & (g < edges[i+1]) & valid + num_in_bin = inds.sum().item() + if num_in_bin > 0: + if mmt > 0: + self.acc_sum[i] = mmt * self.acc_sum[i] \ + + (1 - mmt) * num_in_bin + weights[inds] = tot / self.acc_sum[i] + else: + weights[inds] = tot / num_in_bin + n += 1 + if n > 0: + weights = weights / n + + loss = F.binary_cross_entropy_with_logits( + pred, target, weights, reduction='sum') / tot + return loss * self.loss_weight + +class GHMR(nn.Module): + def __init__( + self, + mu=0.02, + bins=30, + momentum=0, + loss_weight=1.0): + super(GHMR, self).__init__() + self.mu = mu + self.bins = bins + self.edges = [float(x) / bins for x in range(bins+1)] + self.edges[-1] = 1e3 + self.momentum = momentum + if momentum > 0: + self.acc_sum = [0.0 for _ in range(bins)] + self.loss_weight = loss_weight + + def forward(self, pred, target, label_weight, avg_factor=None,reweight=1): + """ Args: + pred [batch_num, 4 (* class_num)]: + The prediction of box regression layer. Channel number can be 4 or + (4 * class_num) depending on whether it is class-agnostic. + target [batch_num, 4 (* class_num)]: + The target regression values with the same size of pred. + label_weight [batch_num, 4 (* class_num)]: + The weight of each sample, 0 if ignored. 
+ """ + mu = self.mu + edges = self.edges + mmt = self.momentum + + # ASL1 loss + diff = pred - target + loss = torch.sqrt(diff * diff + mu * mu) - mu + + # gradient length + g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach() + weights = torch.zeros_like(g) + + valid = label_weight > 0 + tot = max(label_weight.float().sum().item(), 1.0) + n = 0 # n: valid bins + for i in range(self.bins): + inds = (g >= edges[i]) & (g < edges[i+1]) & valid + num_in_bin = inds.sum().item() + if num_in_bin > 0: + n += 1 + if mmt > 0: + self.acc_sum[i] = mmt * self.acc_sum[i] \ + + (1 - mmt) * num_in_bin + weights[inds] = tot / self.acc_sum[i] + else: + weights[inds] = tot / num_in_bin + if n > 0: + weights /= n + + loss = loss * weights + loss = (loss*reweight).sum() / tot + return loss * self.loss_weight + +class LAE(nn.Module): + def __init__(self,mu=0.02, + bins=30, + momentum=0.75, + loss_weight=1.0,db=True,amp=True,noise_db=-50,max_db=22.5): + super(LAE, self).__init__() + self.db=db + self.amp = amp + self.noise_db = noise_db + self.max_db = max_db + if db: + self.ghm_db = GHMR(mu,bins,momentum,loss_weight) + if amp: + self.ghm_amp = GHMR(mu,bins,momentum,loss_weight) + + def forward(self, rec, spec, tracker=None,reweight=1): + if self.db: + loss_db = self.ghm_db(rec,spec,torch.ones(spec.shape),reweight=reweight) + if tracker is not None: + tracker.update(dict(Lae_db=loss_db)) + else: + loss_db = torch.tensor(0.0) + if self.amp: + spec_amp = amplitude(spec,noise_db=self.noise_db,max_db=self.max_db) + rec_amp = amplitude(rec,noise_db=self.noise_db,max_db=self.max_db) + loss_a = self.ghm_amp(rec_amp,spec_amp,torch.ones(spec_amp.shape),reweight=reweight) + if tracker is not None: + tracker.update(dict(Lae_a=loss_a)) + else: + loss_a = torch.tensor(0.0) + return loss_db+loss_a + + +class Model(nn.Module): + def __init__(self, generator="", encoder="", ecog_encoder_name="", + spec_chans = 128, n_formants=2, n_formants_noise=2, n_formants_ecog=2, n_fft=256, noise_db=-50, max_db=22.5, wavebased = False, + with_ecog = False, ghm_loss=True,power_synth=True,apply_flooding=True,ecog_compute_db_loudness=False, + hidden_dim=256,dim_feedforward=256,encoder_only=True,attentional_mask=False,n_heads=1,non_local=False,do_mel_guide = True,noise_from_data=False,specsup=True,\ + onedconfirst=True,rnn_type = 'LSTM',rnn_layers = 4,compute_db_loudness=True,bidirection = True): + super(Model, self).__init__() + self.spec_chans = spec_chans + self.with_ecog = with_ecog + self.ecog_encoder_name = ecog_encoder_name + self.n_formants_ecog = n_formants_ecog + self.wavebased = wavebased + self.n_fft = n_fft + self.n_mels = spec_chans + self.do_mel_guide = do_mel_guide + self.noise_db = noise_db + self.spec_sup = specsup + self.max_db = max_db + self.apply_flooding = apply_flooding + self.n_formants_noise = n_formants_noise + self.power_synth =power_synth + self.decoder = GENERATORS[generator]( + n_mels = spec_chans, + k = 40, + wavebased = wavebased, + n_fft = n_fft, + noise_db = noise_db, + max_db = max_db, + noise_from_data = noise_from_data, + return_wave = False, + power_synth=power_synth, + ) + if do_mel_guide: + self.decoder_mel = GENERATORS[generator]( + n_mels = spec_chans, + k = 40, + wavebased = False, + n_fft = n_fft, + noise_db = noise_db, + max_db = max_db, + add_bgnoise = False, + ) + self.encoder = ENCODERS[encoder]( + n_mels = spec_chans, + n_formants = n_formants, + n_formants_noise = n_formants_noise, + wavebased = wavebased, + hop_length = 128, + n_fft = n_fft, + noise_db = noise_db, + max_db = 
max_db, + power_synth = power_synth, + ) + if with_ecog: + if 'Transformer' in ecog_encoder_name: + self.ecog_encoder = ECOG_ENCODER[ecog_encoder_name]( + n_mels = spec_chans,n_formants = n_formants_ecog, + hidden_dim=hidden_dim,dim_feedforward=dim_feedforward,n_heads=n_heads, + encoder_only=encoder_only,attentional_mask=attentional_mask,non_local=non_local, + compute_db_loudness = ecog_compute_db_loudness, + ) + else: + self.ecog_encoder = ECOG_ENCODER[ecog_encoder_name]( + n_mels = spec_chans,n_formants = n_formants_ecog, + compute_db_loudness = ecog_compute_db_loudness, + ) + self.ghm_loss = ghm_loss + self.lae1 = LAE(noise_db=self.noise_db,max_db=self.max_db) + self.lae2 = LAE(amp=False) + self.lae3 = LAE(amp=False) + self.lae4 = LAE(amp=False) + self.lae5 = LAE(amp=False) + self.lae6 = LAE(amp=False) + self.lae7 = LAE(amp=False) + self.lae8 = LAE(amp=False) + + def noise_dist_init(self,dist): + with torch.no_grad(): + self.decoder.noise_dist = dist.reshape([1,1,1,dist.shape[0]]) + + def generate_fromecog(self, ecog = None, mask_prior = None, mni=None,return_components=False): + components = self.ecog_encoder(ecog, mask_prior,mni) + rec = self.decoder.forward(components) + if return_components: + return rec, components + else: + return rec + + def generate_fromspec(self, spec, return_components=False,x_denoise=None,duomask=False):#,gender='Female'): + components = self.encoder(spec,x_denoise=x_denoise,duomask=duomask)#,gender=gender) + rec = self.decoder.forward(components) + if return_components: + return rec, components + else: + return rec + + def encode(self, spec,x_denoise=None,duomask=False,noise_level = None,x_amp=None):#,gender='Female'): + components = self.encoder(spec,x_denoise=x_denoise,duomask=duomask,noise_level=noise_level,x_amp=x_amp)#,gender=gender) + return components + + def lae(self,spec,rec,db=True,amp=True,tracker=None,GHM=False): + if amp: + spec_amp = amplitude(spec,noise_db=self.noise_db,max_db=self.max_db) + rec_amp = amplitude(rec,noise_db=self.noise_db,max_db=self.max_db) + # if self.power_synth: + # spec_amp_ = spec_amp**0.5 + # rec_amp_ = rec_amp**0.5 + # else: + # spec_amp_ = spec_amp + # rec_amp_ = rec_amp + spec_amp_ = spec_amp + rec_amp_ = rec_amp + if GHM: + Lae_a = self.ghm_loss(rec_amp_,spec_amp_,torch.ones(spec_amp_))#*150 + Lae_a_l2 = torch.tensor([0.]) + else: + Lae_a = (spec_amp_-rec_amp_).abs().mean()#*150 + Lae_a_l2 = torch.sqrt((spec_amp_-rec_amp_)**2+1E-6).mean()#*150 + else: + Lae_a = torch.tensor(0.) + Lae_a_l2 = torch.tensor(0.) + if tracker is not None: + tracker.update(dict(Lae_a=Lae_a,Lae_a_l2=Lae_a_l2)) + if db: + if GHM: + Lae_db = self.ghm_loss(rec,spec,torch.ones(spec))#*150 + Lae_db_l2 = torch.tensor([0.]) + else: + Lae_db = (spec-rec).abs().mean() + Lae_db_l2 = torch.sqrt((spec-rec)**2+1E-6).mean() + else: + Lae_db = torch.tensor(0.) + Lae_db_l2 = torch.tensor(0.) + if tracker is not None: + tracker.update(dict(Lae_db=Lae_db,Lae_db_l2=Lae_db_l2)) + # return (Lae_a + Lae_a_l2)/2. + (Lae_db+Lae_db_l2)/2. + return Lae_a + Lae_db/2. 
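+ # Minimal usage sketch for the loss helpers defined above (GHMR / LAE). Illustrative only:
+ # the tensor shapes and constructor values below are assumptions for demonstration, not values
+ # used anywhere in this file.
+ # ghmr = GHMR(mu=0.02, bins=30, momentum=0.75)
+ # pred, target = torch.randn(4, 1, 16, 64), torch.randn(4, 1, 16, 64) # assumed [B, C, T, F] spectrograms
+ # valid = torch.ones_like(pred) # label_weight: 1 = valid entry, 0 = ignored
+ # loss_ghmr = ghmr(pred, target, valid) # entries in sparsely populated gradient bins receive larger weights
+ # lae = LAE(noise_db=-50, max_db=22.5) # sums a dB-domain and an amplitude-domain GHMR term
+ # loss_lae = lae(pred, target) # scalar: loss_db + loss_a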
+ + def flooding(self, loss,beta): + if self.apply_flooding: + return (loss-beta).abs()+beta + else: + return loss + + def forward(self, spec, ecog, mask_prior, on_stage, on_stage_wider, ae, tracker, encoder_guide, x_mel=None,x_denoise=None, pitch_aug=False, duomask=False, mni=None,debug=False,x_amp=None,hamonic_bias=False,x_amp_from_denoise=False,gender='Female'): + if ae: + self.encoder.requires_grad_(True) + # rec = self.generate_fromspec(spec) + components = self.encoder(spec,x_denoise = x_denoise,duomask=duomask,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp)#,gender=gender) + rec = self.decoder.forward(components) + + freq_cord = torch.arange(self.spec_chans).reshape([1,1,1,self.spec_chans])/(1.0*self.spec_chans) + freq_cord2 = torch.arange(self.spec_chans+1).reshape([1,1,1,self.spec_chans+1])/(1.0*self.spec_chans) + freq_linear_reweighting = 1 if self.wavebased else (inverse_mel_scale(freq_cord2[...,1:])-inverse_mel_scale(freq_cord2[...,:-1]))/440*7 + # freq_linear_reweighting = 1 + # Lae = 4*self.lae((rec*freq_linear_reweighting)[...,128:],(spec*freq_linear_reweighting)[...,128:],tracker=tracker)#torch.mean((rec - spec).abs()*freq_linear_reweighting) + Lae = 4*self.lae(rec*freq_linear_reweighting,spec*freq_linear_reweighting,tracker=tracker)#torch.mean((rec - spec).abs()*freq_linear_reweighting) + if self.wavebased: + spec_amp = amplitude(spec,self.noise_db,self.max_db).transpose(-2,-1) + rec_amp = amplitude(rec,self.noise_db,self.max_db).transpose(-2,-1) + freq_cord2 = torch.arange(128+1).reshape([1,1,1,128+1])/(1.0*128) + freq_linear_reweighting2 = (inverse_mel_scale(freq_cord2[...,1:])-inverse_mel_scale(freq_cord2[...,:-1]))/440*7 + spec_mel = to_db(torchaudio.transforms.MelScale(f_max=8000,n_stft=self.n_fft)(spec_amp).transpose(-2,-1),self.noise_db,self.max_db) + rec_mel = to_db(torchaudio.transforms.MelScale(f_max=8000,n_stft=self.n_fft)(rec_amp).transpose(-2,-1),self.noise_db,self.max_db) + Lae += 4*self.lae(rec_mel*freq_linear_reweighting2,spec_mel*freq_linear_reweighting2,tracker=tracker) + + # hann_win = torch.hann_window(21,periodic=False).reshape([1,1,21,1]) + # spec_broud = to_db(F.conv2d(spec_amp,hann_win,padding=[10,0]).transpose(-2,-1),self.noise_db,self.max_db) + # rec_broud = to_db(F.conv2d(rec_amp,hann_win,padding=[10,0]).transpose(-2,-1),self.noise_db,self.max_db) + # Lae += 4*self.lae(rec_broud,spec_broud,tracker=tracker) + + if self.do_mel_guide: + rec_mel = self.decoder_mel.forward(components) + freq_linear_reweighting_mel = (inverse_mel_scale(freq_cord2[...,1:])-inverse_mel_scale(freq_cord2[...,:-1]))/440*7 + Lae_mel = 4*self.lae(rec_mel*freq_linear_reweighting_mel,x_mel*freq_linear_reweighting_mel,tracker=None) + tracker.update(dict(Lae_mel=Lae_mel)) + Lae+=Lae_mel + + # rec_denoise = self.decoder.forward(components,enable_hamon_excitation=False,enable_noise_excitation=False) + # Lae_noise = (50. 
if self.wavebased else 50.)*self.lae(rec_noise*(1-on_stage_wider.unsqueeze(-1)),spec*(1-on_stage_wider.unsqueeze(-1))) + # tracker.update(dict(Lae_noise=Lae_noise)) + # Lae += Lae_noise + + if x_amp_from_denoise: + if self.wavebased: + if self.power_synth: + Lloudness = 10**6*(components['loudness']*(1-on_stage_wider)).mean() + else: + # Lloudness = 10**3*(components['loudness']*(1-on_stage_wider)).mean() + Lloudness = 10**6*(components['loudness']*(1-on_stage_wider)).mean() + # Lloudness = 10.**6*((components['loudness'])**2*(1-on_stage_wider)).mean() + tracker.update(dict(Lloudness=Lloudness)) + Lae += Lloudness + + if self.wavebased and x_denoise is not None: + thres = int(hz2ind(4000,self.n_fft)) if self.wavebased else mel_scale(self.spec_chans,4000,pt=False).astype(np.int32) + explosive=(torch.mean((spec*freq_linear_reweighting)[...,thres:],dim=-1)>torch.mean((spec*freq_linear_reweighting)[...,:thres],dim=-1)).to(torch.float32).unsqueeze(-1) + rec_denoise = self.decoder.forward(components,enable_hamon_excitation=True,enable_noise_excitation=True,enable_bgnoise=False) + Lae_denoise = 20*self.lae(rec_denoise*freq_linear_reweighting*explosive,x_denoise*freq_linear_reweighting*explosive) + tracker.update(dict(Lae_denoise=Lae_denoise)) + Lae += Lae_denoise + # import pdb;pdb.set_trace() + # if components['freq_formants_hamon'].shape[1] > 2: + freq_limit = self.encoder.formant_freq_limits_abs.squeeze() + from net_formant import mel_scale + freq_limit = hz2ind(freq_limit,self.n_fft).long() if self.wavebased else mel_scale(self.spec_chans,freq_limit).long() + if debug: + import pdb;pdb.set_trace() + + + # if True: + if not self.wavebased: + n_formant_noise = components['freq_formants_noise'].shape[1]-components['freq_formants_hamon'].shape[1] + for formant in range(components['freq_formants_hamon'].shape[1]-1,1,-1): + components_copy = {i:j.clone() for i,j in components.items()} + components_copy['freq_formants_hamon'] = components_copy['freq_formants_hamon'][:,:formant] + components_copy['freq_formants_hamon_hz'] = components_copy['freq_formants_hamon_hz'][:,:formant] + components_copy['bandwidth_formants_hamon'] = components_copy['bandwidth_formants_hamon'][:,:formant] + components_copy['bandwidth_formants_hamon_hz'] = components_copy['bandwidth_formants_hamon_hz'][:,:formant] + components_copy['amplitude_formants_hamon'] = components_copy['amplitude_formants_hamon'][:,:formant] + + if duomask: + # components_copy['freq_formants_noise'] = components_copy['freq_formants_noise'][:,:formant] + # components_copy['freq_formants_noise_hz'] = components_copy['freq_formants_noise_hz'][:,:formant] + # components_copy['bandwidth_formants_noise'] = components_copy['bandwidth_formants_noise'][:,:formant] + # components_copy['bandwidth_formants_noise_hz'] = components_copy['bandwidth_formants_noise_hz'][:,:formant] + # components_copy['amplitude_formants_noise'] = components_copy['amplitude_formants_noise'][:,:formant] + components_copy['freq_formants_noise'] = torch.cat([components_copy['freq_formants_noise'][:,:formant],components_copy['freq_formants_noise'][:,-n_formant_noise:]],dim=1) + components_copy['freq_formants_noise_hz'] = torch.cat([components_copy['freq_formants_noise_hz'][:,:formant],components_copy['freq_formants_noise_hz'][:,-n_formant_noise:]],dim=1) + components_copy['bandwidth_formants_noise'] = torch.cat([components_copy['bandwidth_formants_noise'][:,:formant],components_copy['bandwidth_formants_noise'][:,-n_formant_noise:]],dim=1) + components_copy['bandwidth_formants_noise_hz'] = 
torch.cat([components_copy['bandwidth_formants_noise_hz'][:,:formant],components_copy['bandwidth_formants_noise_hz'][:,-n_formant_noise:]],dim=1) + components_copy['amplitude_formants_noise'] = torch.cat([components_copy['amplitude_formants_noise'][:,:formant],components_copy['amplitude_formants_noise'][:,-n_formant_noise:]],dim=1) + # rec = self.decoder.forward(components_copy,enable_noise_excitation=True) + # Lae += self.lae(rec,spec,tracker=tracker)#torch.mean((rec - spec).abs()) + rec = self.decoder.forward(components_copy,enable_noise_excitation=True if self.wavebased else True) + Lae += 1*self.lae((rec*freq_linear_reweighting),(spec*freq_linear_reweighting),tracker=tracker)#torch.mean(((rec - spec).abs()*freq_linear_reweighting)[...,:freq_limit[formant-1]]) + # Lae += self.lae((rec*freq_linear_reweighting)[...,:freq_limit[formant-1]],(spec*freq_linear_reweighting)[...,:freq_limit[formant-1]],tracker=tracker)#torch.mean(((rec - spec).abs()*freq_linear_reweighting)[...,:freq_limit[formant-1]]) + # Lamp = 1*torch.mean(F.relu(-components['amplitude_formants_hamon'][:,0:3]+components['amplitude_formants_hamon'][:,1:4])*(components['amplitudes'][:,0:1]>components['amplitudes'][:,1:2]).float()) + # tracker.update(dict(Lamp=Lamp)) + # Lae+=Lamp + else: + Lamp = 10*torch.mean(F.relu(-components['amplitude_formants_hamon'][:,0:3]+components['amplitude_formants_hamon'][:,1:4])*(components['amplitudes'][:,0:1]>components['amplitudes'][:,1:2]).float()) + tracker.update(dict(Lamp=Lamp)) + Lae+=Lamp + tracker.update(dict(Lae=Lae)) + if debug: + import pdb;pdb.set_trace() + + + thres = int(hz2ind(4000,self.n_fft)) if self.wavebased else mel_scale(self.spec_chans,4000,pt=False).astype(np.int32) + explosive=torch.sign(torch.mean((spec*freq_linear_reweighting)[...,thres:],dim=-1)-torch.mean((spec*freq_linear_reweighting)[...,:thres],dim=-1))*0.5+0.5 + Lexp = torch.mean((components['amplitudes'][:,0:1]-components['amplitudes'][:,1:2])*explosive)*100 + tracker.update(dict(Lexp=Lexp)) + Lae += Lexp + + if hamonic_bias: + hamonic_loss = 1000*torch.mean((1-components['amplitudes'][:,0])*on_stage) + Lae += hamonic_loss + + # alphaloss=(F.relu(0.5-(components['amplitudes']-0.5).abs())*100).mean() + # Lae+=alphaloss + + if pitch_aug: + pitch_shift = (2**(-1.5+3*torch.rand([components['f0_hz'].shape[0]]).to(torch.float32)).reshape([components['f0_hz'].shape[0],1,1])) # +- 1 octave + # pitch_shift = (2**(torch.randint(-1,2,[components['f0_hz'].shape[0]]).to(torch.float32)).reshape([components['f0_hz'].shape[0],1,1])).clamp(min=88,max=616) # +- 1 octave + components['f0_hz'] = (components['f0_hz']*pitch_shift).clamp(min=88,max=300) + # components['f0'] = mel_scale(self.spec_chans,components['f0'])/self.spec_chans + rec_shift = self.decoder.forward(components) + components_enc = self.encoder(rec_shift,duomask=duomask,x_denoise=x_denoise,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp)#,gender=gender) + Lf0 = torch.mean((components_enc['f0_hz']/200-components['f0_hz']/200)**2) + rec_cycle = self.decoder.forward(components_enc) + Lae += self.lae(rec_shift*freq_linear_reweighting,rec_cycle*freq_linear_reweighting,tracker=tracker)#torch.mean((rec_shift-rec_cycle).abs()*freq_linear_reweighting) + # import pdb;pdb.set_trace() + else: + # Lf0 = torch.mean((F.relu(160 - components['f0_hz']) + F.relu(components['f0_hz']-420))/10) + Lf0 = torch.tensor([0.]) + # Lf0 = torch.tensor([0.]) + tracker.update(dict(Lf0=Lf0)) + + spec = spec.squeeze(dim=1).permute(0,2,1) #B * f * T + 
loudness = torch.mean(spec*0.5+0.5,dim=1,keepdim=True) + # import pdb;pdb.set_trace() + if self.wavebased: + # hamonic_components_diff = compdiffd2(components['f0_hz']*2) + compdiff(components['amplitudes'])*750.# + compdiff(components['amplitude_formants_hamon'])*750. + ((components['loudness']*components['amplitudes'][:,1:]/0.0001)**0.125).mean()*50 + compdiff(components['amplitude_formants_noise'])*750. + if self.power_synth: + hamonic_components_diff = compdiffd2(components['freq_formants_hamon_hz']*1.5) + compdiffd2(components['f0_hz']*2) + compdiff(components['bandwidth_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5) + compdiff(components['freq_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5)+ compdiff(components['amplitudes'])*750. + compdiffd2(components['amplitude_formants_hamon'])*1500.+ compdiffd2(components['amplitude_formants_noise'])*1500.# + ((components['loudness']*components['amplitudes'][:,1:]/0.0001)**0.125).mean()*50 + else: + # hamonic_components_diff = compdiffd2(components['freq_formants_hamon_hz']*1.5) + compdiffd2(components['f0_hz']*2) + compdiff(components['amplitudes'])*750. + compdiffd2(components['amplitude_formants_hamon'])*1500.+ compdiffd2(components['amplitude_formants_noise'])*1500.# + ((components['loudness']*components['amplitudes'][:,1:]/0.0001)**0.125).mean()*50 + hamonic_components_diff = compdiffd2(components['freq_formants_hamon_hz']*1.5) + compdiffd2(components['f0_hz']*2) + compdiff(components['bandwidth_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5) + compdiff(components['freq_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5)+ compdiff(components['amplitudes'])*750. + compdiffd2(components['amplitude_formants_hamon'])*1500.+ compdiffd2(components['amplitude_formants_noise'])*1500.# + ((components['loudness']*components['amplitudes'][:,1:]/0.0001)**0.125).mean()*50 + + # hamonic_components_diff = compdiffd2(components['freq_formants_hamon_hz']*1.5) + compdiffd2(components['f0_hz']*2) + compdiff(components['bandwidth_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5) + compdiff(components['freq_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5)+ compdiff(components['amplitudes'])*750.# + compdiff(components['amplitude_formants_hamon'])*750. + ((components['loudness']*components['amplitudes'][:,1:]/0.0001)**0.125).mean()*50 + compdiff(components['amplitude_formants_noise'])*750. + # hamonic_components_diff = compdiffd2(components['freq_formants_hamon_hz']*1.5) + compdiffd2(components['f0_hz']*8) + compdiff(components['bandwidth_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5) + compdiff(components['freq_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/5)+ compdiff(components['amplitudes'])*750.# + compdiff(components['amplitude_formants_hamon'])*750. + ((components['loudness']*components['amplitudes'][:,1:]/0.0001)**0.125).mean()*50 + compdiff(components['amplitude_formants_noise'])*750. + # hamonic_components_diff = compdiffd2(components['freq_formants_hamon_hz']*2) + compdiffd2(components['f0_hz']/10) + compdiff(components['amplitude_formants_hamon'])*750. + compdiff(components['amplitude_formants_noise'])*750. 
+ compdiffd2(components['freq_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/10) + compdiff(components['bandwidth_formants_noise_hz'][:,components['freq_formants_hamon_hz'].shape[1]:]/10) + # hamonic_components_diff = compdiffd2(components['freq_formants_hamon_hz'])+100*compdiffd2(components['f0_hz']*3) + compdiff(components['amplitude_formants_hamon'])*750. + compdiff(components['amplitude_formants_noise'])*750. #+ compdiff(components['freq_formants_noise_hz']*(1-on_stage_wider)) + else: + hamonic_components_diff = compdiff(components['freq_formants_hamon_hz']*(1-on_stage_wider))+compdiff(components['f0_hz']*(1-on_stage_wider)) + compdiff(components['amplitude_formants_hamon'])*750. + compdiff(components['amplitude_formants_noise'])*750. #+ compdiff(components['freq_formants_noise_hz']*(1-on_stage_wider)) + # hamonic_components_diff = compdiff(components['freq_formants_hamon_hz'])+compdiff(components['f0_hz']) + compdiff(components['amplitude_formants_hamon']*(1-on_stage_wider))*1500. + compdiff(components['amplitude_formants_noise']*(1-on_stage_wider))*1500. + compdiff(components['freq_formants_noise_hz']) + Ldiff = torch.mean(hamonic_components_diff)/2000. + # Ldiff = torch.mean(components['freq_formants_hamon'].var()+components['freq_formants_noise'].var())*10 + tracker.update(dict(Ldiff=Ldiff)) + Lae += Ldiff + # Ldiff = 0 + Lfreqorder = torch.mean(F.relu(components['freq_formants_hamon_hz'][:,:-1]-components['freq_formants_hamon_hz'][:,1:])) #+ (torch.mean(F.relu(components['freq_formants_noise_hz'][:,:-1]-components['freq_formants_noise_hz'][:,1:])) if components['freq_formants_noise_hz'].shape[1]>1 else 0) + + return Lae + Lf0 + Lfreqorder,tracker + else: #ecog to audio + self.encoder.requires_grad_(False) + rec,components_ecog = self.generate_fromecog(ecog,mask_prior,mni=mni,return_components=True) + + ###### mel db flooding + betas = {'loudness':0.01,'freq_formants_hamon':0.0025,'f0_hz':0.,'amplitudes':0.,'amplitude_formants_hamon':0.,'amplitude_formants_noise':0.,'freq_formants_noise':0.05,'bandwidth_formants_noise_hz':0.01} + alpha = {'loudness':1.,'freq_formants_hamon':4.,'f0_hz':1.,'amplitudes':1.,'amplitude_formants_hamon':1.,'amplitude_formants_noise':1.,'freq_formants_noise':1.,'bandwidth_formants_noise_hz':1.} + if self.spec_sup: + if False:#self.ghm_loss: + Lrec = 0.3*self.lae1(rec,spec,tracker=tracker) + else: + Lrec = self.lae(rec,spec,tracker=tracker)#torch.mean((rec - spec)**2) + # Lamp = 10*torch.mean(F.relu(-components_ecog['amplitude_formants_hamon'][:,0:min(3,self.n_formants_ecog-1)]+components_ecog['amplitude_formants_hamon'][:,1:min(4,self.n_formants_ecog)])*(components_ecog['amplitudes'][:,0:1]>components_ecog['amplitudes'][:,1:2]).float()) + # tracker.update(dict(Lamp=Lamp)) + # Lrec+=Lamp + else: + Lrec = torch.tensor([0.0])# + # Lrec = torch.mean((rec - spec).abs()) + tracker.update(dict(Lrec=Lrec)) + Lcomp = 0 + if encoder_guide: + components_guide = self.encode(spec,x_denoise=x_denoise,duomask=duomask,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp)#,gender=gender) + consonant_weight = 1#100*(torch.sign(components_guide['amplitudes'][:,1:]-0.5)*0.5+0.5) + if self.power_synth: + loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + else: + loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # loudness_db_norm = (loudness_db.clamp(min=-35)+35)/25 + loudness_db_norm = 
(loudness_db.clamp(min=-70)+70)/50 + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']**2) + #loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + for key in ['loudness','f0_hz','amplitudes','amplitude_formants_hamon','freq_formants_hamon','amplitude_formants_noise','freq_formants_noise','bandwidth_formants_noise_hz']: + # if 'hz' in key: + # continue + if key == 'loudness': + if self.power_synth: + loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + else: + loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+35)/25 + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key]**2)+70)/50 + if False:#self.ghm_loss: + diff = self.lae2(loudness_db_norm, loudness_db_norm_ecog) + else: + diff = alpha['loudness']*15*torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2)#+ torch.mean((components_guide[key] - components_ecog[key])**2 * on_stage * consonant_weight) + diff = self.flooding(diff,alpha['loudness']*betas['loudness']) + tracker.update({'loudness_metric' : torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2*on_stage_wider)}) + + if key == 'f0_hz': + # diff = torch.mean((components_guide[key]*6 - components_ecog[key]*6)**2 * on_stage_wider * components_guide['loudness']/4) + diff = alpha['f0_hz']*0.3*torch.mean((components_guide[key]/200*5 - components_ecog[key]/200*5)**2 * on_stage_wider * loudness_db_norm) + diff = self.flooding(diff,alpha['f0_hz']*betas['f0_hz']) + tracker.update({'f0_metric' : torch.mean((components_guide['f0_hz']/200*5 - components_ecog['f0_hz']/200*5)**2 * on_stage_wider * loudness_db_norm)}) + + if key in ['amplitudes']: + # if key in ['amplitudes','amplitudes_h']: + weight = on_stage_wider * loudness_db_norm + if self.ghm_loss: + # diff = 100*self.lae3(components_guide[key], components_ecog[key],reweight=weight) + diff = alpha['amplitudes']*30*self.lae3(components_guide[key], components_ecog[key],reweight=weight) + else: + diff = alpha['amplitudes']*10*torch.mean((components_guide[key] - components_ecog[key])**2 *weight) + diff = self.flooding(diff,alpha['amplitudes']*betas['amplitudes']) + tracker.update({'amplitudes_metric' : torch.mean((components_guide['amplitudes'] - components_ecog['amplitudes'])**2 *weight)}) + + if key in ['amplitude_formants_hamon']: + weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + if False:#self.ghm_loss: + diff = 40*self.lae4(components_guide[key][:,:self.n_formants_ecog], components_ecog[key],reweight=weight) + # diff = 10*self.lae4(components_guide[key][:,:self.n_formants_ecog], components_ecog[key],reweight=weight) + else: + # diff = 100*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # diff = 40*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight)/2 \ + # + 40*torch.mean((torchaudio.transforms.AmplitudeToDB()(components_guide[key][:,:self.n_formants_ecog])/100 - torchaudio.transforms.AmplitudeToDB()(components_ecog[key])/100)**2 * weight)/2 + diff = alpha['amplitude_formants_hamon']*40*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # diff = 10*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # diff = 
torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # if key in ['freq_formants_hamon']: + # diff = torch.mean((components_guide[key][:,:1]*10 - components_ecog[key][:,:1]*10)**2 * components_guide['amplitude_formants_hamon'][:,:1] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm ) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]*10 - components_ecog[key]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + diff = self.flooding(diff,alpha['amplitude_formants_hamon']*betas['amplitude_formants_hamon']) + tracker.update({'amplitude_formants_hamon_metric' : torch.mean((components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] - components_ecog['amplitude_formants_hamon'])**2 * weight)}) + + # if key in ['freq_formants_hamon_hz']: + # # weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # weight = components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # if False:#self.ghm_loss: + # diff = 50*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , components_ecog[key]/400, reweight=weight) + # # diff = 15*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , components_ecog[key]/400, reweight=weight) + # else: + # # diff = 300*torch.mean((components_guide['freq_formants_hamon'][:,:2] - components_ecog['freq_formants_hamon'][:,:2])**2 * weight) + # # diff = 300*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # diff = 100*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = 30*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]*10 - components_ecog[key]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + + if key in ['freq_formants_hamon']: + weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # weight = components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + if False:#self.ghm_loss: + diff = 50*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , components_ecog[key]/400, reweight=weight) + # diff = 
15*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , components_ecog[key]/400, reweight=weight) + else: + diff = alpha['freq_formants_hamon']*300*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key][:,:self.n_formants_ecog])**2 * weight) + # diff = 300*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # diff = 100*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # diff = 30*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]*10 - components_ecog[key]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + # if key in ['bandwidth_formants_hamon']: + # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]/4 - components_ecog[key]/4)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]/4 - components_ecog[key]/4)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + diff = self.flooding(diff,alpha['freq_formants_hamon']*betas['freq_formants_hamon']) + tracker.update({'freq_formants_hamon_hz_metric_2' : torch.mean((components_guide['freq_formants_hamon_hz'][:,:2]/400 - components_ecog['freq_formants_hamon_hz'][:,:2]/400)**2 * weight)}) + tracker.update({','+str(self.n_formants_ecog) : torch.mean((components_guide['freq_formants_hamon_hz'][:,:self.n_formants_ecog]/400 - components_ecog['freq_formants_hamon_hz'][:,:self.n_formants_ecog]/400)**2 * weight)}) + + if key in ['amplitude_formants_noise']: + weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight * loudness_db_norm + if False:#self.ghm_loss: + diff = self.lae6(components_guide[key],components_ecog[key],reweight=weight) + else: + # diff = 40*torch.mean((torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1) - components_ecog[key])**2 *weight)/2 \ + # + 40*torch.mean((torchaudio.transforms.AmplitudeToDB()(torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1))/100 - torchaudio.transforms.AmplitudeToDB()(components_ecog[key])/100)**2 * weight)/2 + diff = alpha['amplitude_formants_noise']*40*torch.mean((torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1) - components_ecog[key])**2 *weight) + diff = 
self.flooding(diff,alpha['amplitude_formants_noise']*betas['amplitude_formants_noise']) + tracker.update({'amplitude_formants_noise_metric': torch.mean((torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1) - components_ecog[key])**2 *weight)}) + + if key in ['freq_formants_noise']: + weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + if False:#self.ghm_loss: + diff = 10*self.lae7(components_guide[key][:,-self.n_formants_noise:]/400,components_ecog[key][:,-self.n_formants_noise:]/400,reweight=weight) + else: + + # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + diff = alpha['freq_formants_noise']*12000*torch.mean((components_guide[key][:,-self.n_formants_noise:] - components_ecog[key][:,-self.n_formants_noise:])**2 * weight) + # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_noise'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight) + diff = self.flooding(diff,alpha['freq_formants_noise']*betas['freq_formants_noise']) + tracker.update({'freq_formants_noise_metic': torch.mean((components_guide['freq_formants_noise_hz'][:,-self.n_formants_noise:]/2000*5 - components_ecog['freq_formants_noise_hz'][:,-self.n_formants_noise:]/2000*5)**2 * weight)}) + + # if key in ['freq_formants_noise_hz']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + # if False:#self.ghm_loss: + # diff = 10*self.lae7(components_guide[key][:,-self.n_formants_noise:]/400,components_ecog[key][:,-self.n_formants_noise:]/400,reweight=weight) + # else: + + # # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # diff = 3*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_noise'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight) + # diff = self.flooding(diff,betas['freq_formants_noise_hz']) + # tracker.update({'freq_formants_noise_hz_metic': torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight)}) + + if key in ['bandwidth_formants_noise_hz']: + weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + if False:#self.ghm_loss: + diff = 3*self.lae8(components_guide[key][:,-self.n_formants_noise:]/2000*5, components_ecog[key][:,-self.n_formants_noise:]/2000*5,reweight=weight) + else: + # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + diff = alpha['bandwidth_formants_noise_hz']*3*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + diff = self.flooding(diff,alpha['bandwidth_formants_noise_hz']*betas['bandwidth_formants_noise_hz']) + tracker.update({'bandwidth_formants_noise_hz_metic': 
torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight)}) + + # if key in ['bandwidth_formants_noise']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + # if False:#self.ghm_loss: + # diff = 3*self.lae8(components_guide[key][:,-self.n_formants_noise:], components_ecog[key][:,-self.n_formants_noise:],reweight=weight) + # else: + # # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:] - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # diff = 300*torch.mean((components_guide[key][:,-self.n_formants_noise:] - components_ecog[key][:,-self.n_formants_noise:])**2 * weight) + # diff = self.flooding(diff,betas['bandwidth_formants_noise']) + # tracker.update({'bandwidth_formants_noise_metic': torch.mean((components_guide['bandwidth_formants_noise_hz'][:,-self.n_formants_noise:]/2000*5 - components_ecog['bandwidth_formants_noise_hz'][:,-self.n_formants_noise:]/2000*5)**2 * weight)}) + tracker.update({key : diff}) + Lcomp += diff + # import pdb; pdb.set_trace() + + Loss = Lrec+Lcomp + + hamonic_components_diff = compdiffd2(components_ecog['freq_formants_hamon_hz']*1.5) + compdiffd2(components_ecog['f0_hz']*2) + compdiff(components_ecog['bandwidth_formants_noise_hz'][:,components_ecog['freq_formants_hamon_hz'].shape[1]:]/5) + compdiff(components_ecog['freq_formants_noise_hz'][:,components_ecog['freq_formants_hamon_hz'].shape[1]:]/5)+ compdiff(components_ecog['amplitudes'])*750. + Ldiff = torch.mean(hamonic_components_diff)/2000. + tracker.update(dict(Ldiff=Ldiff)) + Loss += Ldiff + + freq_linear_reweighting = 1 + thres = int(hz2ind(4000,self.n_fft)) if self.wavebased else mel_scale(self.spec_chans,4000,pt=False).astype(np.int32) + explosive=torch.sign(torch.mean((spec*freq_linear_reweighting)[...,thres:],dim=-1)-torch.mean((spec*freq_linear_reweighting)[...,:thres],dim=-1))*0.5+0.5 + Lexp = torch.mean((components_ecog['amplitudes'][:,0:1]-components_ecog['amplitudes'][:,1:2])*explosive)*100 + tracker.update(dict(Lexp=Lexp)) + Loss += Lexp + + Lfreqorder = torch.mean(F.relu(components_ecog['freq_formants_hamon_hz'][:,:-1]-components_ecog['freq_formants_hamon_hz'][:,1:])) + Loss += Lfreqorder + + return Loss,tracker + # ''' + # loss_weights_dict = {'Lrec':1,'loudness':15,'f0_hz':0.3,'amplitudes':30,\ + # 'amplitude_formants_hamon':40, 'freq_formants_hamon_hz':200,'amplitude_formants_noise':40,\ + # 'freq_formants_noise_hz':3,'bandwidth_formants_noise_hz':1,'Ldiff':1/2000.,\ + # 'Lexp':100, 'Lfreqorder':1 } + # self.encoder.requires_grad_(False) + # rec,components_ecog = self.generate_fromecog(ecog,mask_prior,mni=mni,return_components=True) + + # ###### + # if self.spec_sup: + # if False:#self.ghm_loss: + # Lrec = 0.3*self.lae1(rec,spec,tracker=tracker) + # else: + # Lrec = loss_weights_dict['Lrec']*self.lae(rec,spec,tracker=tracker)#torch.mean((rec - spec)**2) + # # Lamp = 10*torch.mean(F.relu(-components_ecog['amplitude_formants_hamon'][:,0:min(3,self.n_formants_ecog-1)]+components_ecog['amplitude_formants_hamon'][:,1:min(4,self.n_formants_ecog)])*(components_ecog['amplitudes'][:,0:1]>components_ecog['amplitudes'][:,1:2]).float()) + # # tracker.update(dict(Lamp=Lamp)) + # # Lrec+=Lamp + # else: + # Lrec = torch.tensor([0.0])# + # # Lrec = torch.mean((rec - spec).abs()) + # tracker.update(dict(Lrec=Lrec)) + # Lcomp = 0 + # if encoder_guide: + # components_guide = 
self.encode(spec,x_denoise=x_denoise,duomask=duomask,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp) + # consonant_weight = 1#100*(torch.sign(components_guide['amplitudes'][:,1:]-0.5)*0.5+0.5) + # if self.power_synth: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # else: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # # loudness_db_norm = (loudness_db.clamp(min=-35)+35)/25 + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']**2) + # #loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # for key in ['loudness','f0_hz','amplitudes','amplitude_formants_hamon','freq_formants_hamon_hz','amplitude_formants_noise','freq_formants_noise_hz','bandwidth_formants_noise']: + # # if 'hz' in key: + # # continue + # #''' + # if key == 'loudness': + # if self.power_synth: + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + # else: + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+35)/25 + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key]**2)+70)/50 + # if False:#self.ghm_loss: + # diff = self.lae2(loudness_db_norm, loudness_db_norm_ecog) + # else: + # diff = loss_weights_dict[key]*torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2)#+ torch.mean((components_guide[key] - components_ecog[key])**2 * on_stage * consonant_weight) + # #''' + # if key == 'loudness': + # if self.power_synth: + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + # else: + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+35)/25 + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key]**2)+70)/50 + # if False:#self.ghm_loss: + # diff = self.lae2(loudness_db_norm, loudness_db_norm_ecog) + # else: + # diff = alpha['loudness']*loss_weights_dict[key]*torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2)#+ torch.mean((components_guide[key] - components_ecog[key])**2 * on_stage * consonant_weight) + # diff = self.flooding(diff,alpha['loudness']*betas['loudness']) + # tracker.update({'loudness_metric' : torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2*on_stage_wider)}) + + # if key == 'f0_hz': + # # diff = torch.mean((components_guide[key]*6 - components_ecog[key]*6)**2 * on_stage_wider * components_guide['loudness']/4) + # diff = loss_weights_dict[key]*torch.mean((components_guide[key]/200*5 - components_ecog[key]/200*5)**2 * on_stage_wider * loudness_db_norm) + # if key in ['amplitudes']: + # # if key in ['amplitudes','amplitudes_h']: + # weight = on_stage_wider * loudness_db_norm + # if self.ghm_loss: + # # diff = 100*self.lae3(components_guide[key], components_ecog[key],reweight=weight) + # diff = loss_weights_dict[key]*self.lae3(components_guide[key], components_ecog[key],reweight=weight) + # else: + # diff = 10*torch.mean((components_guide[key] - components_ecog[key])**2 *weight) + # if key in ['amplitude_formants_hamon']: + # weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * 
loudness_db_norm + # if False:#self.ghm_loss: + # diff = 40*self.lae4(components_guide[key][:,:self.n_formants_ecog], components_ecog[key],reweight=weight) + # # diff = 10*self.lae4(components_guide[key][:,:self.n_formants_ecog], components_ecog[key],reweight=weight) + # else: + # # diff = 100*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # # diff = 40*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight)/2 \ + # # + 40*torch.mean((torchaudio.transforms.AmplitudeToDB()(components_guide[key][:,:self.n_formants_ecog])/100 - torchaudio.transforms.AmplitudeToDB()(components_ecog[key])/100)**2 * weight)/2 + # diff = loss_weights_dict[key]*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # # diff = 10*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # if key in ['freq_formants_hamon']: + # # diff = torch.mean((components_guide[key][:,:1]*10 - components_ecog[key][:,:1]*10)**2 * components_guide['amplitude_formants_hamon'][:,:1] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm ) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]*10 - components_ecog[key]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + + # if key in ['freq_formants_hamon_hz']: + # # weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # weight = components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # if False:#self.ghm_loss: + # diff = 50*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , components_ecog[key]/400, reweight=weight) + # # diff = 15*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , components_ecog[key]/400, reweight=weight) + # else: + # # diff = 300*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # diff = loss_weights_dict[key]*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = 30*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]*10 - components_ecog[key]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] 
* on_stage_wider * consonant_weight) + # # if key in ['bandwidth_formants_hamon']: + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]/4 - components_ecog[key]/4)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]/4 - components_ecog[key]/4)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + + # if key in ['amplitude_formants_noise']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight * loudness_db_norm + # if False:#self.ghm_loss: + # diff = self.lae6(components_guide[key],components_ecog[key],reweight=weight) + # else: + # # diff = 40*torch.mean((torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1) - components_ecog[key])**2 *weight)/2 \ + # # + 40*torch.mean((torchaudio.transforms.AmplitudeToDB()(torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1))/100 - torchaudio.transforms.AmplitudeToDB()(components_ecog[key])/100)**2 * weight)/2 + # diff = loss_weights_dict[key]*torch.mean((torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1) - components_ecog[key])**2 *weight) + + # if key in ['freq_formants_noise_hz']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + # if False:#self.ghm_loss: + # diff = 10*self.lae7(components_guide[key][:,-self.n_formants_noise:]/400,components_ecog[key][:,-self.n_formants_noise:]/400,reweight=weight) + # else: + # # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # diff = loss_weights_dict[key]*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_noise'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight) + # if key in ['bandwidth_formants_noise_hz']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + # if False:#self.ghm_loss: + # diff = 3*self.lae8(components_guide[key][:,-self.n_formants_noise:]/2000*5, components_ecog[key][:,-self.n_formants_noise:]/2000*5,reweight=weight) + # else: + # # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # diff = loss_weights_dict[key]*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # tracker.update({key : diff}) + # Lcomp += diff + # # import pdb; pdb.set_trace() + + # Loss = Lrec+Lcomp + + # hamonic_components_diff = 
compdiffd2(components_ecog['freq_formants_hamon_hz']*1.5) + compdiffd2(components_ecog['f0_hz']*2) + compdiff(components_ecog['bandwidth_formants_noise_hz'][:,components_ecog['freq_formants_hamon_hz'].shape[1]:]/5) + compdiff(components_ecog['freq_formants_noise_hz'][:,components_ecog['freq_formants_hamon_hz'].shape[1]:]/5)+ compdiff(components_ecog['amplitudes'])*750. + # Ldiff = loss_weights_dict['Ldiff']*torch.mean(hamonic_components_diff) + # tracker.update(dict(Ldiff=Ldiff)) + # Loss += Ldiff + + # freq_linear_reweighting = 1 + # thres = int(hz2ind(4000,self.n_fft)) if self.wavebased else mel_scale(self.spec_chans,4000,pt=False).astype(np.int32) + # explosive=torch.sign(torch.mean((spec*freq_linear_reweighting)[...,thres:],dim=-1)-torch.mean((spec*freq_linear_reweighting)[...,:thres],dim=-1))*0.5+0.5 + # Lexp = loss_weights_dict['Lexp']*torch.mean((components_ecog['amplitudes'][:,0:1]-components_ecog['amplitudes'][:,1:2])*explosive) + # tracker.update(dict(Lexp=Lexp)) + # Loss += Lexp + + # Lfreqorder = loss_weights_dict['Lfreqorder']*torch.mean(F.relu(components_ecog['freq_formants_hamon_hz'][:,:-1]-components_ecog['freq_formants_hamon_hz'][:,1:])) + # Loss += Lfreqorder + # #print ('tracker content:',tracker,tracker.keys()) + # return Loss,tracker + # ''' + + ######### new balanced loss + # if self.spec_sup: + # if False:#self.ghm_loss: + # Lrec = 0.3*self.lae1(rec,spec,tracker=tracker) + # else: + # Lrec = self.lae(rec,spec,tracker=tracker)#torch.mean((rec - spec)**2) + # # Lamp = 10*torch.mean(F.relu(-components_ecog['amplitude_formants_hamon'][:,0:min(3,self.n_formants_ecog-1)]+components_ecog['amplitude_formants_hamon'][:,1:min(4,self.n_formants_ecog)])*(components_ecog['amplitudes'][:,0:1]>components_ecog['amplitudes'][:,1:2]).float()) + # # tracker.update(dict(Lamp=Lamp)) + # # Lrec+=Lamp + # else: + # Lrec = torch.tensor([0.0])# + # # Lrec = torch.mean((rec - spec).abs()) + # tracker.update(dict(Lrec=Lrec)) + # Lcomp = 0 + # if encoder_guide: + # components_guide = self.encode(spec,x_denoise=x_denoise,duomask=duomask,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp) + # consonant_weight = 1#100*(torch.sign(components_guide['amplitudes'][:,1:]-0.5)*0.5+0.5) + # if self.power_synth: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # else: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # # loudness_db_norm = (loudness_db.clamp(min=-35)+35)/25 + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']**2) + # for key in ['loudness','f0_hz','amplitudes','amplitude_formants_hamon','freq_formants_hamon_hz','amplitude_formants_noise','freq_formants_noise_hz','bandwidth_formants_noise_hz']: + # # if 'hz' in key: + # # continue + # if key == 'loudness': + # if self.power_synth: + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + # else: + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+35)/25 + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog[key])+70)/50 + # if False:#self.ghm_loss: + # diff = self.lae2(loudness_db_norm, loudness_db_norm_ecog) + # else: + # diff = 10*torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2*on_stage_wider) + 1.5*10**6*torch.mean((components_guide['loudness'] - 
components_ecog['loudness'])**2*on_stage_wider) + 2*10**7*(components_ecog['loudness']**2*(1-on_stage_wider)).mean()#+ torch.mean((components_guide[key] - components_ecog[key])**2 * on_stage * consonant_weight) + # if key == 'f0_hz': + # # diff = torch.mean((components_guide[key]*6 - components_ecog[key]*6)**2 * on_stage_wider * components_guide['loudness']/4) + # diff = torch.mean((components_guide[key]/200*5 - components_ecog[key]/200*5)**2 * on_stage_wider * loudness_db_norm) + # if key in ['amplitudes']: + # # if key in ['amplitudes','amplitudes_h']: + # weight = on_stage_wider * loudness_db_norm + # if self.ghm_loss: + # # diff = 100*self.lae3(components_guide[key], components_ecog[key],reweight=weight) + # diff = 30*self.lae3(components_guide[key], components_ecog[key],reweight=weight) + # else: + # diff = 10*torch.mean((components_guide[key] - components_ecog[key])**2 *weight) + # if key in ['amplitude_formants_hamon']: + # weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # if False:#self.ghm_loss: + # diff = 40*self.lae4(components_guide[key][:,:self.n_formants_ecog], components_ecog[key],reweight=weight) + # # diff = 10*self.lae4(components_guide[key][:,:self.n_formants_ecog], components_ecog[key],reweight=weight) + # else: + # # diff = 100*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # # diff = 40*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight)/2 \ + # # + 40*torch.mean((torchaudio.transforms.AmplitudeToDB()(components_guide[key][:,:self.n_formants_ecog])/100 - torchaudio.transforms.AmplitudeToDB()(components_ecog[key])/100)**2 * weight)/2 + # diff = 20*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # # diff = 10*torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # if key in ['freq_formants_hamon']: + # # diff = torch.mean((components_guide[key][:,:1]*10 - components_ecog[key][:,:1]*10)**2 * components_guide['amplitude_formants_hamon'][:,:1] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm ) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]*10 - components_ecog[key]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + + # if key in ['freq_formants_hamon_hz']: + # # weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # weight = components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # if False:#self.ghm_loss: + # diff = 50*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , components_ecog[key]/400, reweight=weight) + # # diff = 15*self.lae5(components_guide[key][:,:self.n_formants_ecog]/400 , 
components_ecog[key]/400, reweight=weight) + # else: + # diff = 150*torch.mean((components_guide['freq_formants_hamon'][:,:self.n_formants_ecog] - components_ecog['freq_formants_hamon'][:,:self.n_formants_ecog])**2 * weight) \ + # + 5*torch.mean((components_guide['freq_formants_hamon_hz'][:,:self.n_formants_ecog]/400 - components_ecog['freq_formants_hamon_hz'][:,:self.n_formants_ecog]/400)**2 * weight) + # # diff = 300*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = 100*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = 30*torch.mean((components_guide[key][:,:self.n_formants_ecog]/2000*5 - components_ecog[key][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]*10 - components_ecog[key]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + # # if key in ['bandwidth_formants_hamon']: + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]/4 - components_ecog[key]/4)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog]/4 - components_ecog[key]/4)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + + # if key in ['amplitude_formants_noise']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight * loudness_db_norm + # if False:#self.ghm_loss: + # diff = self.lae6(components_guide[key],components_ecog[key],reweight=weight) + # else: + # # diff = 40*torch.mean((torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1) - components_ecog[key])**2 *weight)/2 \ + # # + 40*torch.mean((torchaudio.transforms.AmplitudeToDB()(torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1))/100 - torchaudio.transforms.AmplitudeToDB()(components_ecog[key])/100)**2 * weight)/2 + # diff = 40*torch.mean((torch.cat([components_guide[key][:,:self.n_formants_ecog],components_guide[key][:,-self.n_formants_noise:]],dim=1) - components_ecog[key])**2 *weight) + + # if key in ['freq_formants_noise_hz']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + # if False:#self.ghm_loss: + # diff = 10*self.lae7(components_guide[key][:,-self.n_formants_noise:]/400,components_ecog[key][:,-self.n_formants_noise:]/400,reweight=weight) + # else: + # # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - 
components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # diff = 1.5*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # # diff = torch.mean((components_guide[key][:,:self.n_formants_ecog] - components_ecog[key])**2 * components_guide['amplitude_formants_noise'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight) + # if key in ['bandwidth_formants_noise_hz']: + # weight = components_guide['amplitudes'][:,1:2] * on_stage_wider * consonant_weight* loudness_db_norm + # if False:#self.ghm_loss: + # diff = 3*self.lae8(components_guide[key][:,-self.n_formants_noise:]/2000*5, components_ecog[key][:,-self.n_formants_noise:]/2000*5,reweight=weight) + # else: + # # diff = 30*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + # diff = 4*torch.mean((components_guide[key][:,-self.n_formants_noise:]/2000*5 - components_ecog[key][:,-self.n_formants_noise:]/2000*5)**2 * weight) + + # if key in ['loudness','freq_formants_hamon_hz']: + # diff = diff*10. + # tracker.update({key : diff}) + # Lcomp += diff + # Lcomp = Lcomp/20. + # Loss = Lrec+Lcomp + + # hamonic_components_diff = compdiffd2(components_ecog['freq_formants_hamon_hz']*1.5) + compdiffd2(components_ecog['f0_hz']*2) + compdiff(components_ecog['bandwidth_formants_noise_hz'][:,components_ecog['freq_formants_hamon_hz'].shape[1]:]/5) + compdiff(components_ecog['freq_formants_noise_hz'][:,components_ecog['freq_formants_hamon_hz'].shape[1]:]/5)+ compdiff(components_ecog['amplitudes'])*750. + # Ldiff = torch.mean(hamonic_components_diff)/2000. + # tracker.update(dict(Ldiff=Ldiff)) + # Loss += Ldiff + + # freq_linear_reweighting = 1 + # thres = int(hz2ind(4000,self.n_fft)) if self.wavebased else mel_scale(self.spec_chans,4000,pt=False).astype(np.int32) + # explosive=torch.sign(torch.mean((spec*freq_linear_reweighting)[...,thres:],dim=-1)-torch.mean((spec*freq_linear_reweighting)[...,:thres],dim=-1))*0.5+0.5 + # Lexp = torch.mean((components_ecog['amplitudes'][:,0:1]-components_ecog['amplitudes'][:,1:2])*explosive)*100 + # tracker.update(dict(Lexp=Lexp)) + # Loss += Lexp + + # Lfreqorder = torch.mean(F.relu(components_ecog['freq_formants_hamon_hz'][:,:-1]-components_ecog['freq_formants_hamon_hz'][:,1:])) + # Loss += Lfreqorder + + # return Loss + + ##### fit f1 freq only + # components_guide = self.encode(spec,x_denoise=x_denoise,duomask=duomask,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp) + # if self.power_synth: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # else: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # #loudness_db_norm = (loudness_db.clamp(min=-35)+35)/25 + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']**2) + # # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # weight = components_guide['amplitudes'][:,0:1] * on_stage_wider * 1 * loudness_db_norm + # # weight = components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # if False:#self.ghm_loss: + # diff = 
50*self.lae5(components_guide[freq_formants_hamon_hz][:,:self.n_formants_ecog]/400 , components_ecog[freq_formants_hamon_hz]/400, reweight=weight) + # # diff = 15*self.lae5(components_guide[freq_formants_hamon_hz][:,:self.n_formants_ecog]/400 , components_ecog[freq_formants_hamon_hz]/400, reweight=weight) + # else: + # diff = 75*torch.mean((components_guide['freq_formants_hamon'][:,:2] - components_ecog['freq_formants_hamon'][:,:2])**2 * weight) \ + # + 0.5*torch.mean((components_guide['freq_formants_hamon_hz'][:,:2]/400 - components_ecog['freq_formants_hamon_hz'][:,:2]/400)**2 * weight) + # # diff = 300*torch.mean((components_guide['freq_formants_hamon_hz'][:,:2]/2000*5 - components_ecog['freq_formants_hamon_hz'][:,:2]/2000*5)**2 * weight) + # # diff = 100*torch.mean((components_guide[freq_formants_hamon_hz][:,:self.n_formants_ecog]/2000*5 - components_ecog[freq_formants_hamon_hz][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = 30*torch.mean((components_guide[freq_formants_hamon_hz][:,:self.n_formants_ecog]/2000*5 - components_ecog[freq_formants_hamon_hz][:,:self.n_formants_ecog]/2000*5)**2 * weight) + # # diff = torch.mean((components_guide[freq_formants_hamon_hz][:,:self.n_formants_ecog]*10 - components_ecog[freq_formants_hamon_hz]*10)**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * components_guide['loudness']/4 * consonant_weight) + # # diff = torch.mean((components_guide[freq_formants_hamon_hz][:,:self.n_formants_ecog] - components_ecog[freq_formants_hamon_hz])**2 * components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_ecog['amplitudes'][:,0:1] * on_stage_wider * consonant_weight) + # Loss = diff + # return Loss + + #### fit loudness freq only + # components_guide = self.encode(spec,x_denoise=x_denoise,duomask=duomask,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp) + # if self.power_synth: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # else: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # #loudness_db_norm = (loudness_db.clamp(min=-35)+35)/25 + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']**2) + # # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # if self.power_synth: + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness'])+70)/50 + # else: + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness'])+35)/25 + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness'])+70)/50 + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness']**2)+70)/50 + # weight = components_guide['loudness'][:,0:1] * on_stage_wider * 1 * loudness_db_norm + # # weight = components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # diff = 3*torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2) + 10**6*torch.mean((components_guide['loudness'] - components_ecog['loudness'])**2) + # Loss = diff + + # Lloudness = 10**7*(components_ecog['loudness']**2*(1-on_stage_wider)).mean() + # Loss+=Lloudness + # return Loss + + #### fit f1f2 freq and 
loudness + # components_guide = self.encode(spec,x_denoise=x_denoise,duomask=duomask,noise_level = F.softplus(self.decoder.bgnoise_amp)*self.decoder.noise_dist.mean(),x_amp=x_amp) + # if self.power_synth: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # else: + # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']) + # #loudness_db_norm = (loudness_db.clamp(min=-35)+35)/25 + # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # # loudness_db = torchaudio.transforms.AmplitudeToDB()(components_guide['loudness']**2) + # # loudness_db_norm = (loudness_db.clamp(min=-70)+70)/50 + # if self.power_synth: + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness'])+70)/50 + # else: + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness'])+35)/25 + # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness'])+70)/50 + # # loudness_db_norm_ecog = (torchaudio.transforms.AmplitudeToDB()(components_ecog['loudness']**2)+70)/50 + # weight = components_guide['loudness'][:,0:1] * on_stage_wider * 1 * loudness_db_norm + # # weight = components_guide['amplitude_formants_hamon'][:,:self.n_formants_ecog] * components_guide['amplitudes'][:,0:1] * on_stage_wider * consonant_weight * loudness_db_norm + # diff = 3*torch.mean((loudness_db_norm - loudness_db_norm_ecog)**2) + 10**6*torch.mean((components_guide['loudness'] - components_ecog['loudness'])**2) + # Loss = diff + + # Lloudness = 10**7*(components_ecog['loudness']**2*(1-on_stage_wider)).mean() + # Loss+=Lloudness + + # weight = on_stage_wider * 1 * loudness_db_norm + # diff = 75*torch.mean((components_guide['freq_formants_hamon'][:,:2] - components_ecog['freq_formants_hamon'][:,:2])**2 * weight) \ + # + 0.5*torch.mean((components_guide['freq_formants_hamon_hz'][:,:2]/400 - components_ecog['freq_formants_hamon_hz'][:,:2]/400)**2 * weight) + # Loss+=diff + # return Loss + + def lerp(self, other, betta,w_classifier=False): + if hasattr(other, 'module'): + other = other.module + with torch.no_grad(): + params = list(self.decoder.parameters()) + list(self.encoder.parameters()) + (list(self.ecog_encoder.parameters()) if self.with_ecog else []) + (list(self.decoder_mel.parameters()) if self.do_mel_guide else []) + other_param = list(other.decoder.parameters()) + list(other.encoder.parameters()) + (list(other.ecog_encoder.parameters()) if self.with_ecog else []) + (list(self.decoder_mel.parameters()) if self.do_mel_guide else []) + for p, p_other in zip(params, other_param): + p.data.lerp_(p_other.data, 1.0 - betta) +# \ No newline at end of file diff --git a/model_param.json b/model_param.json new file mode 100644 index 00000000..986c7ff4 --- /dev/null +++ b/model_param.json @@ -0,0 +1,37 @@ +{ + "conv_S_1":{ + "in_C": 1, + "out_C": 8, + "k_siz": 8 + }, + "conv_S_2":{ + "in_C": 8, + "out_C": 16, + "k_siz": 8 + }, + "conv_S_3":{ + "in_C": 16, + "out_C": 32, + "k_siz": 8 + }, + "conv_U_1":{ + "in_C": 32, + "out_C": 32, + "k_siz": 8 + }, + "conv_U_2":{ + "in_C": 32, + "out_C": 32, + "k_siz": 8 + }, + "conv_U_3":{ + "in_C": 32, + "out_C": 32, + "k_siz": 8 + }, + "conv_O_1":{ + "in_C": 32, + "out_C": 1, + "k_siz": 1 + } +} \ No newline at end of file diff --git a/net.py b/net.py index 76b93d2f..aced7ea0 100644 --- a/net.py +++ b/net.py @@ -17,6 +17,7 @@ import torch from torch import nn from torch.nn import functional as F +from 
torch.nn import Parameter as P from torch.nn import init from torch.nn.parameter import Parameter import numpy as np @@ -24,14 +25,38 @@ import math from registry import * +PARTIAL_SN = False +USE_SN = False +def sn(module,use_sn=USE_SN): + if use_sn: + return torch.nn.utils.spectral_norm(module) + else: + return module + def pixel_norm(x, epsilon=1e-8): return x * torch.rsqrt(torch.mean(x.pow(2.0), dim=1, keepdim=True) + epsilon) -def style_mod(x, style): - style = style.view(style.shape[0], 2, x.shape[1], 1, 1) - return torch.addcmul(style[:, 1], value=1.0, tensor1=x, tensor2=style[:, 0] + 1) +def style_mod(x, style1, style2=None, bias = True): + if style1.dim()==2: + style1 = style1.view(style1.shape[0], 2, x.shape[1], 1, 1) + elif style1.dim()==3: + style1 = style1.view(style1.shape[0], 2, x.shape[1], style1.shape[2], 1) + if style2 is None: + if bias: + return torch.addcmul(style1[:, 1], value=1.0, tensor1=x, tensor2=style1[:, 0] + 1) + else: + return x*(style1[:,0]+1) + else: + if style2.dim()==2: + style2 = style2.view(style2.shape[0], 2, x.shape[1], 1, 1) + elif style2.dim()==3: + style2 = style2.view(style2.shape[0], 2, x.shape[1], style2.shape[2], 1) + if bias: + return torch.addcmul(style1[:, 1]+style2[:, 1], value=1.0, tensor1=x, tensor2=(style1[:, 0] + 1)*(style2[:, 0] + 1)) + else: + return x*(style1[:,0]+1)*(style2[:,0]+1) def upscale2d(x, factor=2): @@ -45,6 +70,12 @@ def upscale2d(x, factor=2): def downscale2d(x, factor=2): return F.avg_pool2d(x, factor, factor) +class Downsample(nn.Module): + def __init__(self,scale_factor): + super(Downsample, self).__init__() + self.scale_factor = scale_factor + def forward(self,x): + return F.interpolate(x,scale_factor=1/self.scale_factor) class Blur(nn.Module): def __init__(self, channels): @@ -59,71 +90,419 @@ def __init__(self, channels): def forward(self, x): return F.conv2d(x, weight=self.weight, groups=self.groups, padding=1) +class AdaIN(nn.Module): + def __init__(self, latent_size,outputs,temporal_w=False,global_w=True,temporal_global_cat = False): + super(AdaIN, self).__init__() + self.instance_norm = nn.InstanceNorm2d(outputs,affine=False, eps=1e-8) + self.global_w = global_w + self.temporal_w = temporal_w + self.temporal_global_cat = temporal_global_cat and (temporal_w and global_w) + if temporal_w and global_w: + if self.temporal_global_cat: + self.style = sn(ln.Conv1d(2*latent_size, 2 * outputs,1,1,0,gain=1)) + else: + self.style = sn(ln.Conv1d(latent_size, 2 * outputs,1,1,0,gain=1)) + self.style_global = sn(ln.Linear(latent_size, 2 * outputs, gain=1)) + else: + if temporal_w: + self.style = sn(ln.Conv1d(latent_size, 2 * outputs,1,1,0,gain=1)) + if global_w: + self.style = sn(ln.Linear(latent_size, 2 * outputs, gain=1)) + def forward(self,x,w=None,w_global=None): + x = self.instance_norm(x) + if self.temporal_w and self.global_w: + if self.temporal_global_cat: + w = torch.cat((w,w_global.unsqueeze(2).repeat(1,1,w.shape[2])),dim=1) + x = style_mod(x,self.style(w)) + else: + x = style_mod(x,self.style(w),self.style_global(w_global)) + else: + x = style_mod(x,self.style(w)) + return x + +class INencoder(nn.Module): + def __init__(self, inputs,latent_size,temporal_w=False,global_w=True,temporal_global_cat = False,use_statistic=True): + super(INencoder, self).__init__() + self.temporal_w = temporal_w + self.global_w = global_w + self.temporal_global_cat = temporal_global_cat and (temporal_w and global_w) + self.use_statistic = use_statistic + self.instance_norm = nn.InstanceNorm2d(inputs,affine=False) + if global_w and 
not(temporal_w): + self.style = sn(ln.Linear((2 * inputs) if use_statistic else inputs , latent_size)) + if temporal_w and not(global_w): + self.style = sn(ln.Conv1d((2 * inputs) if use_statistic else inputs, latent_size,1,1,0)) + if temporal_w and global_w: + if self.temporal_global_cat: + self.style = sn(ln.Conv1d((4 * inputs) if use_statistic else inputs, 2*latent_size,1,1,0)) + else: + self.style_local = sn(ln.Conv1d((2 * inputs) if use_statistic else inputs, latent_size,1,1,0)) + self.style_global = sn(ln.Linear((2 * inputs) if use_statistic else inputs, latent_size)) + + def forward(self,x): + m_local = torch.mean(x, dim=[3], keepdim=True) + std_local = torch.sqrt(torch.mean((x - m_local) ** 2, dim=[3], keepdim=True)+1e-8) + m_global = torch.mean(x, dim=[2,3], keepdim=True) + std_global = torch.sqrt(torch.mean((x - m_global) ** 2, dim=[2,3], keepdim=True)+1e-8) + if self.use_statistic: + style_local = torch.cat((m_local,std_local),dim=1) + style_global = torch.cat((m_global,std_global),dim=1) + else: + style_local = x + style_global = x + x = self.instance_norm(x) + if self.global_w and not(self.temporal_w): + w = self.style(style_global.view(style_global.shape[0], style_global.shape[1])) + return x,w + if self.temporal_w and not(self.global_w): + w = self.style(style_local.view(style_local.shape[0], style_local.shape[1],style_local.shape[2])) + return x,w + if self.temporal_w and self.global_w: + if self.temporal_global_cat: + if self.use_statistic: + style = torch.cat((style_local,style_global.repeat(1,1,style_local.shape[2],1)),dim=1) + else: + style = style_local + w = self.style(style.view(style.shape[0], style.shape[1],style.shape[2])) + w_local = w[:,:w.shape[1]//2] + w_global = torch.mean(w[:,w.shape[1]//2:],dim=[2]) + else: + w_local = self.style_local(style_local.view(style_local.shape[0], style_local.shape[1],style_local.shape[2])) + if not self.use_statistic: + style_global = style_global.mean(dim=[2]) + w_global = self.style_global(style_global.view(style_global.shape[0], style_global.shape[1])) + return x,w_local,w_global + +class Attention(nn.Module): + def __init__(self, inputs,temporal_w=False,global_w=True,temporal_global_cat = False,attentional_style=False,decoding=True,latent_size=None,heads=1,demod=False): + super(Attention, self).__init__() + # Channel multiplier + self.inputs = inputs + self.temporal_w = temporal_w + self.global_w = global_w + self.decoding = decoding + self.attentional_style = attentional_style + self.att_denorm = 8 + self.heads = heads + self.demod = demod + self.theta = sn(ln.Conv2d(inputs, inputs // self.att_denorm, 1,1,0, bias=False)) + self.phi = sn(ln.Conv2d(inputs, inputs // self.att_denorm, 1,1,0, bias=False)) + self.g = sn(ln.Conv2d(inputs, inputs // 2, 1,1,0, bias=False)) + self.o = sn(ln.Conv2d(inputs // 2, inputs, 1,1,0, bias=False)) + if not attentional_style: + self.norm_theta = nn.InstanceNorm2d(inputs // self.att_denorm,affine=True) + self.norm_phi = nn.InstanceNorm2d(inputs // self.att_denorm,affine=True) + self.norm_g = nn.InstanceNorm2d(inputs // 2,affine=True) + else: + if decoding: + self.norm_theta = AdaIN(latent_size,inputs//self.att_denorm,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + self.norm_phi = AdaIN(latent_size,inputs//self.att_denorm,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + self.norm_g = AdaIN(latent_size,inputs//2,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + else: + self.norm_theta = 
INencoder(inputs//self.att_denorm,latent_size,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + self.norm_phi = INencoder(inputs//self.att_denorm,latent_size,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + self.norm_g = INencoder(inputs//2,latent_size,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + + if demod and attentional_style: + self.theta = ln.StyleConv2d(inputs, inputs // self.att_denorm,kernel_size=1,latent_size=latent_size,stride=1,padding=0, + bias=False, upsample=False,temporal_w=temporal_w,transform_kernel=False) + self.phi = ln.StyleConv2d(inputs, inputs // self.att_denorm,kernel_size=1,latent_size=latent_size,stride=1,padding=0, + bias=False, upsample=False,temporal_w=temporal_w,transform_kernel=False) + self.g = ln.StyleConv2d(inputs, inputs // 2,kernel_size=1,latent_size=latent_size,stride=1,padding=0, + bias=True, upsample=False,temporal_w=temporal_w,transform_kernel=False) + self.o = ln.Conv2d(inputs // 2, inputs, 1,1,0, bias=True) + + + # Learnable gain parameter + self.gamma = P(torch.tensor(0.), requires_grad=True) + + def forward(self, x, w_local=None,w_global=None): + # Apply convs + x = x.contiguous() + theta = self.theta(x) + phi = F.max_pool2d(self.phi(x), [2,2]) + g = F.max_pool2d(self.g(x), [2,2]) + if w_local is not None and w_local.dim()==3: + w_local_down = F.avg_pool1d(w_local, 2) + else: + w_local_down = w_local + if not self.demod: + theta = self.norm_theta(theta,w_local,w_global) if (self.attentional_style and self.decoding) else self.norm_theta(theta) + phi = self.norm_phi(phi,w_local_down,w_global) if (self.attentional_style and self.decoding) else self.norm_phi(phi) + g = self.norm_g(g,w_local_down,w_global) if (self.attentional_style and self.decoding) else self.norm_g(g) + if self.attentional_style and not self.decoding: + if self.temporal_w and self.global_w: + theta,w_theta_local,w_theta_global = theta + phi,w_phi_local,w_phi_global = phi + g,w_g_local,w_g_global = g + w_phi_local = F.interpolate(w_phi_local,scale_factor=2,mode='linear') + w_g_local = F.interpolate(w_g_local,scale_factor=2,mode='linear') + w_local = w_theta_local+w_phi_local+w_g_local + w_global = w_theta_global+w_phi_global+w_g_global + else: + theta,w_theta = theta + phi,w_phi = phi + g,w_g = g + if w_phi.dim()==3: + w_phi = F.interpolate(w_phi,scale_factor=2,mode='linear') + w_g = F.interpolate(w_g,scale_factor=2,mode='linear') + w = w_theta+w_phi+w_g + + # Perform reshapes + self.theta_ = theta.reshape(-1, self.inputs // self.att_denorm//self.heads, self.heads ,x.shape[2] * x.shape[3]) + self.phi_ = phi.reshape(-1, self.inputs // self.att_denorm//self.heads, self.heads, x.shape[2] * x.shape[3] // 4) + g = g.reshape(-1, self.inputs // 2//self.heads, self.heads, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + self.beta = F.softmax(torch.einsum('bchi,bchj->bhij',self.theta_, self.phi_), -1) + # self.beta = F.softmax(torch.bmm(self.theta_, self.phi_), -1) + # Attention map times g path + o = self.o(torch.einsum('bchj,bhij->bchi',g, self.beta).reshape(-1, self.inputs // 2, x.shape[2], x.shape[3])) + # o = self.o(torch.bmm(g, self.beta.transpose(1,2)).view(-1, self.inputs // 2, x.shape[2], x.shape[3])) + if (not self.attentional_style) or self.decoding: + return self.gamma * o + x + else: + if self.temporal_w and self.global_w: + return self.gamma * o + x, w_local, w_global + else: + return self.gamma * o + x, w + +class ToWLatent(nn.Module): + def 
__init__(self,inputs,latent_size,temporal_w=False,from_input=False): + super(ToWLatent,self).__init__() + self.temporal_w = temporal_w + self.from_input = from_input + if temporal_w: + self.style = sn(ln.Conv1d(inputs, latent_size,1,1,0),use_sn=USE_SN) + else: + self.style = sn(ln.Linear(inputs, latent_size),use_sn=USE_SN) + + def forward(self, x): + if not self.from_input: + if self.temporal_w: + m = torch.mean(x, dim=[3], keepdim=True) + std = torch.sqrt(torch.mean((x - m) ** 2, dim=[3], keepdim=True)+1e-8) + else: + m = torch.mean(x, dim=[2, 3], keepdim=True) + std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True)+1e-8) + else: + std = x + if self.temporal_w: + w = self.style(std.view(std.shape[0], std.shape[1],std.shape[2])) + else: + w = self.style(std.view(std.shape[0], std.shape[1])) + return w + +class ECoGMappingBlock(nn.Module): + def __init__(self, inputs, outputs, kernel_size,dilation=1,fused_scale=True,residual=False,resample=[]): + super(ECoGMappingBlock, self).__init__() + self.residual = residual + self.inputs_resample = resample + self.dim_missmatch = (inputs!=outputs) + self.resample = resample + if not self.resample: + self.resample=1 + self.padding = list(np.array(dilation)*(np.array(kernel_size)-1)//2) + # self.padding = [dilation[i]*(kernel_size[i]-1)//2 for i in range(len(dilation))] + if residual: + self.norm1 = nn.GroupNorm(min(inputs,32),inputs) + else: + self.norm1 = nn.GroupNorm(min(outputs,32),outputs) + self.conv1 = sn(ln.Conv3d(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False)) + if self.inputs_resample or self.dim_missmatch: + self.convskip = sn(ln.Conv3d(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False)) + + self.conv2 = sn(ln.Conv3d(outputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False)) + self.norm2 = nn.GroupNorm(min(outputs,32),outputs) + + def forward(self,x): + if self.residual: + x = F.leaky_relu(self.norm1(x),0.2) + if self.inputs_resample or self.dim_missmatch: + # x_skip = F.avg_pool3d(x,self.resample,self.resample) + x_skip = self.convskip(x) + else: + x_skip = x + x = F.leaky_relu(self.norm2(self.conv1(x)),0.2) + x = self.conv2(x) + x = x_skip + x + else: + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + return x + + + +class DemodEncodeBlock(nn.Module): + def __init__(self, inputs, outputs, latent_size, last=False,fused_scale=True,temporal_w=False,temporal_samples=None,resample=False,spec_chans=None,attention=False,attentional_style=False,heads=1,channels=1): + super(DemodEncodeBlock, self).__init__() + self.last = last + self.temporal_w = temporal_w + self.attention = attention + self.resample = resample + self.fused_scale = False if temporal_w else fused_scale + self.attentional_style = attentional_style + self.fromrgb = FromRGB(channels, inputs,style = True,residual=False,temporal_w=temporal_w,latent_size=latent_size) + self.conv1 = sn(ln.Conv2d(inputs, inputs, 3, 1, 1, bias=True)) + self.style1 = ToWLatent(inputs,latent_size,temporal_w=temporal_w) + self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False) + self.blur = Blur(inputs) + if attention: + self.non_local = Attention(inputs,temporal_w=temporal_w,attentional_style=attentional_style,decoding=False,latent_size=latent_size,heads=heads) + if last: + if self.temporal_w: + self.conv_2 = sn(ln.Conv2d(inputs * spec_chans, outputs, 3, 1, 1, bias=True)) + else: + self.dense = sn(ln.Linear(inputs * temporal_samples * 
spec_chans, outputs, bias = True)) + else: + if resample and fused_scale: + self.conv_2 = sn(ln.Conv2d(inputs, outputs, 3, 2, 1, bias=True, transform_kernel=True)) + else: + self.conv_2 = sn(ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)) + self.style2 = ToWLatent(outputs,latent_size,temporal_w=temporal_w,from_input=last) + self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False) + + def forward(self, spec,x): + spec_feature,w0 = self.fromrgb(spec) + x = (x+spec_feature) if (x is not None) else spec_feature + if self.attention: + x = self.non_local(x) + if self.attentional_style: + x,w_attn = x + x = F.leaky_relu(self.conv1(x),0.2) + w1 = self.style1(x) + x = self.instance_norm_1(x) + + if self.last: + if self.temporal_w: + x = self.conv_2(x.view(x.shape[0], -1,x.shape[2])) + x = F.leaky_relu(x, 0.2) + w2 = self.style2(x.view(x.shape[0], x.shape[1],x.shape[2])) + else: + x = self.dense(x.view(x.shape[0], -1)) + x = F.leaky_relu(x, 0.2) + w2 = self.style2(x.view(x.shape[0], x.shape[1])) + + else: + x = F.leaky_relu(self.conv_2(self.blur(x))) + if not self.fused_scale: + x = downscale2d(x) + w2 = self.style2(x) + x = self.instance_norm_2(x) + + spec = F.avg_pool2d(spec,2,2) + w = (w0+w1+w2+w_attn) if self.attentional_style else (w0+w1+w2) + return spec,x,w class EncodeBlock(nn.Module): - def __init__(self, inputs, outputs, latent_size, last=False, fused_scale=True): + def __init__(self, inputs, outputs, latent_size, last=False,islast=False, fused_scale=True,temporal_w=False,global_w=True,temporal_global_cat = False,residual=False,resample=False,temporal_samples=None,spec_chans=None): super(EncodeBlock, self).__init__() - self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) + self.conv_1 = sn(ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)) # self.conv_1 = ln.Conv2d(inputs + (1 if last else 0), inputs, 3, 1, 1, bias=False) self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) - self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False) self.blur = Blur(inputs) self.last = last - self.fused_scale = fused_scale + self.islast = islast + self.fused_scale = False if temporal_w else fused_scale + self.residual = residual + self.resample=resample + self.temporal_w = temporal_w + self.global_w = global_w + self.temporal_global_cat = temporal_global_cat and (temporal_w and global_w) if last: - self.dense = ln.Linear(inputs * 4 * 4, outputs) + if self.temporal_w: + self.conv_2 = sn(ln.Conv2d(inputs * spec_chans, outputs, 3, 1, 1, bias=False)) + else: + self.dense = sn(ln.Linear(inputs * temporal_samples * spec_chans, outputs)) else: - if fused_scale: - self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) + if resample and self.fused_scale: + self.conv_2 = sn(ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)) else: - self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) + self.conv_2 = sn(ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)) self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) - self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False) - self.style_1 = ln.Linear(2 * inputs, latent_size) - if last: - self.style_2 = ln.Linear(outputs, latent_size) - else: - self.style_2 = ln.Linear(2 * outputs, latent_size) + self.style_1 = INencoder(inputs,latent_size,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat,use_statistic=True) + self.style_2 = INencoder(outputs,latent_size,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat,use_statistic=not(last)) 
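Editorial sketch (not part of the patch): the EncodeBlock rewrite above routes its statistics through the new INencoder modules, whose temporal (per-frame) and global style vectors are later recombined by the patched two-argument style_mod(). The snippet below is a minimal, self-contained illustration of how that combination is assumed to behave, with made-up shapes and a hypothetical name (style_mod_sketch); the real modules obtain the styles from AdaIN/INencoder rather than from random tensors.

import torch

def style_mod_sketch(x, style1, style2=None, bias=True):
    # Each style tensor packs a (scale, bias) pair per channel; dim 1 splits into two halves.
    if style1.dim() == 2:      # global style: [B, 2*C]
        style1 = style1.view(style1.shape[0], 2, x.shape[1], 1, 1)
    elif style1.dim() == 3:    # temporal style: [B, 2*C, T]
        style1 = style1.view(style1.shape[0], 2, x.shape[1], style1.shape[2], 1)
    if style2 is None:
        out = x * (style1[:, 0] + 1)
        if bias:
            out = out + style1[:, 1]
        return out
    if style2.dim() == 2:
        style2 = style2.view(style2.shape[0], 2, x.shape[1], 1, 1)
    elif style2.dim() == 3:
        style2 = style2.view(style2.shape[0], 2, x.shape[1], style2.shape[2], 1)
    # When a temporal and a global style are combined, the scales multiply and the biases add.
    out = x * (style1[:, 0] + 1) * (style2[:, 0] + 1)
    if bias:
        out = out + (style1[:, 1] + style2[:, 1])
    return out

# Illustrative usage with assumed sizes (batch 2, 8 channels, 16 frames, 32 freq bins):
B, C, T, F_bins = 2, 8, 16, 32
x = torch.randn(B, C, T, F_bins)       # feature map: batch x channels x time x freq
w_temporal = torch.randn(B, 2 * C, T)  # per-frame scale/bias, as produced by a temporal style head
w_global = torch.randn(B, 2 * C)       # one scale/bias per channel, as produced by a global style head
y = style_mod_sketch(x, w_temporal, w_global)
print(y.shape)                         # torch.Size([2, 8, 16, 32])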
+ # self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False) + + # if self.temporal_w: + # self.style_1 = sn(ln.Conv1d(2 * inputs, latent_size,1,1,0),use_sn=PARTIAL_SN) + # if last: + # self.style_2 = sn(ln.Conv1d(outputs, latent_size,1,1,0),use_sn=PARTIAL_SN) + # else: + # self.style_2 = sn(ln.Conv1d(2 * outputs, latent_size,1,1,0),use_sn=PARTIAL_SN) + # else: + # self.style_1 = sn(ln.Linear(2 * inputs, latent_size),use_sn=PARTIAL_SN) + # if last: + # self.style_2 = sn(ln.Linear(outputs, latent_size),use_sn=PARTIAL_SN) + # else: + # self.style_2 = sn(ln.Linear(2 * outputs, latent_size),use_sn=PARTIAL_SN) + + if residual and not islast: + if inputs==outputs: + if not resample: + self.skip = nn.Identity() + else: + self.skip = Downsample(scale_factor=2) + else: + if not resample: + self.skip = nn.Sequential( + sn(ln.Conv2d(inputs, outputs, 1, 1, 0, bias=False),use_sn=PARTIAL_SN), + nn.InstanceNorm2d(outputs, affine=True, eps=1e-8) + ) + else: + self.skip = nn.Sequential( + sn(ln.Conv2d(inputs, outputs, 1, 2, 0, bias=False, transform_kernel=True),use_sn=PARTIAL_SN), + nn.InstanceNorm2d(outputs, affine=True, eps=1e-8) + ) with torch.no_grad(): self.bias_1.zero_() self.bias_2.zero_() def forward(self, x): + if self.residual: + x = F.leaky_relu(x,0.2) + x_input = x x = self.conv_1(x) + self.bias_1 + if self.temporal_w and self.global_w: + x,w1_local,w1_global = self.style_1(x) + else: + x,w1 = self.style_1(x) x = F.leaky_relu(x, 0.2) - m = torch.mean(x, dim=[2, 3], keepdim=True) - std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True)) - style_1 = torch.cat((m, std), dim=1) - - x = self.instance_norm_1(x) - if self.last: - x = self.dense(x.view(x.shape[0], -1)) + if self.temporal_w: + x = self.conv_2(x.view(x.shape[0], -1,x.shape[2])) + else: + x = self.dense(x.view(x.shape[0], -1)) x = F.leaky_relu(x, 0.2) - w1 = self.style_1(style_1.view(style_1.shape[0], style_1.shape[1])) - w2 = self.style_2(x.view(x.shape[0], x.shape[1])) + if self.temporal_w and self.global_w: + x,w2_local,w2_global = self.style_2(x) + else: + x,w2 = self.style_2(x) else: x = self.conv_2(self.blur(x)) - if not self.fused_scale: - x = downscale2d(x) x = x + self.bias_2 - x = F.leaky_relu(x, 0.2) - - m = torch.mean(x, dim=[2, 3], keepdim=True) - std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True)) - style_2 = torch.cat((m, std), dim=1) - - x = self.instance_norm_2(x) - - w1 = self.style_1(style_1.view(style_1.shape[0], style_1.shape[1])) - w2 = self.style_2(style_2.view(style_2.shape[0], style_2.shape[1])) + if self.temporal_w and self.global_w: + x,w2_local,w2_global = self.style_2(x) + else: + x,w2 = self.style_2(x) + + if not self.fused_scale: + x = downscale2d(x) - return x, w1, w2 + if not self.islast: + if self.residual: + x = self.skip(x_input)+x + else: + x = F.leaky_relu(x, 0.2) + + if self.temporal_w and self.global_w: + return x, w1_local, w1_global, w2_local, w2_global + else: + return x, w1, w2 class DiscriminatorBlock(nn.Module): @@ -169,31 +548,60 @@ def forward(self, x): class DecodeBlock(nn.Module): - def __init__(self, inputs, outputs, latent_size, has_first_conv=True, fused_scale=True, layer=0): + def __init__(self, inputs, outputs, latent_size, has_first_conv=True, fused_scale=True, layer=0,temporal_w=False,global_w=True,temporal_global_cat = False,residual=False,resample = False): super(DecodeBlock, self).__init__() self.has_first_conv = has_first_conv self.inputs = inputs self.has_first_conv = has_first_conv + self.temporal_w = temporal_w + self.global_w = global_w + 
self.temporal_global_cat = temporal_global_cat and (temporal_w and global_w) self.fused_scale = fused_scale + self.residual =residual + self.resample = resample if has_first_conv: - if fused_scale: - self.conv_1 = ln.ConvTranspose2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) + if resample and fused_scale: + self.conv_1 = sn(ln.ConvTranspose2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True),use_sn=PARTIAL_SN) else: - self.conv_1 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) + self.conv_1 = sn(ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False),use_sn=PARTIAL_SN) self.blur = Blur(outputs) self.noise_weight_1 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) self.noise_weight_1.data.zero_() self.bias_1 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) - self.instance_norm_1 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) - self.style_1 = ln.Linear(latent_size, 2 * outputs, gain=1) - - self.conv_2 = ln.Conv2d(outputs, outputs, 3, 1, 1, bias=False) + # self.instance_norm_1 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) + # if temporal_w: + # self.style_1 = sn(ln.Conv1d(latent_size, 2 * outputs,1,1,0, gain=1),use_sn=PARTIAL_SN) + # self.style_2 = sn(ln.Conv1d(latent_size, 2 * outputs,1,1,0, gain=1),use_sn=PARTIAL_SN) + # else: + # self.style_1 = sn(ln.Linear(latent_size, 2 * outputs, gain=1),use_sn=PARTIAL_SN) + # self.style_2 = sn(ln.Linear(latent_size, 2 * outputs, gain=1),use_sn=PARTIAL_SN) + self.style_1 = AdaIN(latent_size,outputs,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + self.style_2 = AdaIN(latent_size,outputs,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat) + + self.conv_2 = sn(ln.Conv2d(outputs, outputs, 3, 1, 1, bias=False),use_sn=PARTIAL_SN) self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) self.noise_weight_2.data.zero_() self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) - self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) - self.style_2 = ln.Linear(latent_size, 2 * outputs, gain=1) + # self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) + + if residual and has_first_conv: + if inputs==outputs: + if not resample: + self.skip = nn.Identity() + else: + self.skip = nn.Upsample(scale_factor=2) + else: + if not resample: + self.skip = nn.Sequential( + sn(ln.Conv2d(inputs, outputs, 1, 1, 0, bias=False),use_sn=PARTIAL_SN), + nn.InstanceNorm2d(outputs, affine=True, eps=1e-8) + ) + else: + self.skip = nn.Sequential( + sn(ln.ConvTranspose2d(inputs, outputs, 1, 2, 0, bias=False, transform_kernel=True),use_sn=PARTIAL_SN), + nn.InstanceNorm2d(outputs, affine=True, eps=1e-8) + ) self.layer = layer @@ -201,8 +609,11 @@ def __init__(self, inputs, outputs, latent_size, has_first_conv=True, fused_scal self.bias_1.zero_() self.bias_2.zero_() - def forward(self, x, s1, s2, noise): + def forward(self, x, s1, s2, noise, s1_global=None, s2_global=None): if self.has_first_conv: + if self.residual: + x = F.leaky_relu(x) + x_input = x if not self.fused_scale: x = upscale2d(x) x = self.conv_1(x) @@ -220,11 +631,14 @@ def forward(self, x, s1, s2, noise): x = x + s * torch.exp(-x * x / (2.0 * s * s)) / math.sqrt(2 * math.pi) * 0.8 x = x + self.bias_1 - x = F.leaky_relu(x, 0.2) - - x = self.instance_norm_1(x) + # x = self.instance_norm_1(x) + # x = style_mod(x, self.style_1(s1)) + if self.temporal_w and self.global_w: + x = self.style_1(x,s1,s1_global) + else: + x = self.style_1(x,s1) - x = style_mod(x, self.style_1(s1)) + x = 
F.leaky_relu(x, 0.2) x = self.conv_2(x) @@ -237,45 +651,431 @@ def forward(self, x, s1, s2, noise): tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]])) else: s = math.pow(self.layer + 1, 0.5) - x = x + s * torch.exp(-x * x / (2.0 * s * s)) / math.sqrt(2 * math.pi) * 0.8 + x = x + s * torch.exp(-x * x / (2.0 * s * s)) / math.sqrt(2 * math.pi) * 0.8 x = x + self.bias_2 - x = F.leaky_relu(x, 0.2) - x = self.instance_norm_2(x) + # x = self.instance_norm_2(x) + # x = style_mod(x, self.style_2(s2)) + if self.temporal_w and self.global_w: + x = self.style_2(x,s2,s2_global) + else: + x = self.style_2(x,s2) - x = style_mod(x, self.style_2(s2)) + if self.residual: + if self.has_first_conv: + x = self.skip(x_input)+x + else: + x = F.leaky_relu(x, 0.2) return x +class DemodDecodeBlock(nn.Module): + def __init__(self, inputs, outputs, latent_size, has_first_conv=True, fused_scale = True, layer=0,temporal_w=False,attention=False,attentional_style=False,heads=1,channels=1): + super(DemodDecodeBlock, self).__init__() + self.has_first_conv = has_first_conv + self.inputs = inputs + self.has_first_conv = has_first_conv + self.temporal_w = temporal_w + self.attention = attention + self.layer = layer + if has_first_conv: + if fused_scale: + self.conv1 = ln.StyleConv2dtest(inputs, outputs, kernel_size=3, latent_size=latent_size, stride=2 , padding=1, + bias=True,upsample=True,temporal_w=temporal_w,transform_kernel=True,transpose = True) + else: + self.conv1 = ln.StyleConv2dtest(inputs, outputs, kernel_size=3, latent_size=latent_size, stride=1 , padding=1, + bias=True,upsample=True,temporal_w=temporal_w,transform_kernel=False,transpose = False) + self.conv2 = ln.StyleConv2dtest(outputs, outputs, kernel_size=3, latent_size=latent_size, stride=1 , padding=1, + bias=True,upsample=False,temporal_w=temporal_w,transform_kernel=False) + else: + self.conv1 = ln.StyleConv2dtest(inputs, outputs, kernel_size=3, latent_size=latent_size, stride=1 , padding=1, + bias=True,upsample=False,temporal_w=temporal_w,transform_kernel=False) + self.skip = ToRGB(outputs,channels,style=False,residual=False,temporal_w=temporal_w,latent_size=latent_size) + # self.skip = ToRGB(outputs,channels,style=True,residual=False,temporal_w=temporal_w,latent_size=latent_size) + if attention: + self.att = Attention(outputs,temporal_w=temporal_w,attentional_style=attentional_style,decoding=True,latent_size=latent_size,heads=heads,demod=True) + self.blur = Blur(channels) + + def forward(self,x,y,w,noise): + x = F.leaky_relu(self.conv1(x,w,noise=noise),0.2) + if not noise: + s = math.pow(self.layer + 1, 0.5) + x = x + s * torch.exp(-x * x / (2.0 * s * s)) / math.sqrt(2 * math.pi) * 0.8 + if self.has_first_conv: + x = F.leaky_relu(self.conv2(x,w,noise=noise),0.2) + if not noise: + x = x + s * torch.exp(-x * x / (2.0 * s * s)) / math.sqrt(2 * math.pi) * 0.8 + if self.attention: + x = F.leaky_relu(self.att(x,w)) + skip = self.skip(x,w) + if y is not None: + y = upscale2d(y) + y = self.blur(y) + # y = F.interpolate(y,scale_factor=2,mode='bilinear') + return (y+skip, x) if (y is not None) else (skip,x) + + +class FromECoG(nn.Module): + def __init__(self, outputs,residual=False): + super().__init__() + self.residual=residual + self.from_ecog = sn(ln.Conv3d(1, outputs, [9,1,1], 1, [4,0,0])) + + def forward(self, x): + x = self.from_ecog(x) + if not self.residual: + x = F.leaky_relu(x, 0.2) + return x class FromRGB(nn.Module): - def __init__(self, channels, outputs): + def __init__(self, channels, outputs,style = 
False,residual=False,temporal_w=False,latent_size=None): super(FromRGB, self).__init__() - self.from_rgb = ln.Conv2d(channels, outputs, 1, 1, 0) + self.residual=residual + self.from_rgb = sn(ln.Conv2d(channels, outputs, 1, 1, 0)) + self.style = style + self.temporal_w = temporal_w + if style: + self.stylelayer = ToWLatent(outputs,latent_size,temporal_w=temporal_w) def forward(self, x): x = self.from_rgb(x) - x = F.leaky_relu(x, 0.2) + if self.style: + w = self.stylelayer(x) + if not self.residual: + x = F.leaky_relu(x, 0.2) - return x + return x if not self.style else (x, w) class ToRGB(nn.Module): - def __init__(self, inputs, channels): + def __init__(self, inputs, channels,style = False,residual=False,temporal_w=False,latent_size=None): super(ToRGB, self).__init__() self.inputs = inputs self.channels = channels - self.to_rgb = ln.Conv2d(inputs, channels, 1, 1, 0, gain=0.03) + self.residual = residual + self.style = style + if style: + self.to_rgb = ln.StyleConv2dtest(inputs, channels, kernel_size=1, latent_size=latent_size, stride=1 , padding=0, gain=0.03, + bias=True,upsample=False,temporal_w=temporal_w,transform_kernel=False,demod=False) + else: + self.to_rgb = sn(ln.Conv2d(inputs, channels, 1, 1, 0, gain=0.03),use_sn=PARTIAL_SN) - def forward(self, x): - x = self.to_rgb(x) + def forward(self, x,w=None): + if self.residual: + x = F.leaky_relu(x, 0.2) + x = self.to_rgb(x,w) if (self.style and (w is not None) ) else self.to_rgb(x) return x +@ECOG_ENCODER.register("ECoGMappingDilation") +class ECoGMapping_Dilation(nn.Module): + def __init__(self, latent_size,average_w = False,temporal_w=False,global_w=True,attention=None,temporal_samples=None,attentional_style=False,heads=1): + super(ECoGMapping_Dilation, self).__init__() + self.temporal_w = temporal_w + self.global_w = global_w + self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,dilation=[2,1,1]) + self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,dilation = [4,1,1]) + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + # self.mask = ln.Conv3d(64,1,[3,1,1],1,[4,0,0],dilation = [4,1,1]) + self.conv3 = ECoGMappingBlock(64,128,[3,3,3],residual=True,dilation = [8,2,2]) + self.conv4 = ECoGMappingBlock(128,256,[3,3,3],residual=True,dilation = [16,4,4]) + self.norm = nn.GroupNorm(32,256) + self.conv5 = ln.Conv1d(256,256,3,1,16,dilation=16) + if self.temporal_w: + self.norm2 = nn.GroupNorm(32,256) + self.conv6 = ln.Conv1d(256,256,3,1,1) + self.norm3 = nn.GroupNorm(32,256) + self.conv7 = ln.Conv1d(256,latent_size,3,1,1) + if self.global_w: + self.linear1 = ln.Linear(256*8,512) + self.linear2 = ln.Linear(512,latent_size) + def forward(self,ecog,mask_prior): + x_all_global = [] + x_all_local = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,8:-8] + if mask_prior is not None: + mask = mask*mask_prior_d + x = x[:,:,8:-8] + x = x*mask + x = self.conv3(x) + x = self.conv4(x) + x = x.max(-1)[0].max(-1)[0] + x_common = self.conv5(F.leaky_relu(self.norm(x),0.2)) + if self.global_w: + x_global = F.max_pool1d(x_common,16,16) + x_global = x_global.flatten(1) + x_global = self.linear1(F.leaky_relu(x_global,0.2)) + x_global = self.linear2(F.leaky_relu(x_global,0.2)) + x_global = F.leaky_relu(x_global,0.2) + x_all_global += 
[x_global] + if self.temporal_w: + x_local = self.conv6(F.leaky_relu(self.norm2(x_common),0.2)) + x_local = self.conv7(F.leaky_relu(self.norm3(x_local),0.2)) + x_local = F.leaky_relu(x_local,0.2) + x_all_local += [x_local] + if self.global_w and self.temporal_w: + x_all = (torch.cat(x_all_local,dim=0),torch.cat(x_all_global,dim=0)) + else: + if self.temporal_w: + x_all = torch.cat(x_all_local,dim=0) + else: + x_all = torch.cat(x_all_global,dim=0) + return x_all + +@ECOG_ENCODER.register("ECoGMappingBottleneck") +class ECoGMapping_Bottleneck(nn.Module): + def __init__(self, latent_size,average_w = False,temporal_w=False,global_w=True,attention=None,temporal_samples=None,attentional_style=False,heads=1): + super(ECoGMapping_Bottleneck, self).__init__() + self.temporal_w = temporal_w + self.global_w = global_w + self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,resample = [2,1,1]) + self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,resample = [2,1,1]) + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + self.conv3 = ECoGMappingBlock(64,128,[3,3,3],residual=True,resample = [2,2,2]) + self.conv4 = ECoGMappingBlock(128,256,[3,3,3],residual=True,resample = [2,2,2]) + self.norm = nn.GroupNorm(32,256) + self.conv5 = ln.Conv1d(256,256,3,1,1) + if self.temporal_w: + self.norm2 = nn.GroupNorm(32,256) + self.conv6 = ln.ConvTranspose1d(256, 128, 3, 2, 1, transform_kernel=True) + self.norm3 = nn.GroupNorm(32,128) + self.conv7 = ln.ConvTranspose1d(128, 64, 3, 2, 1, transform_kernel=True) + self.norm4 = nn.GroupNorm(32,64) + self.conv8 = ln.ConvTranspose1d(64, 32, 3, 2, 1, transform_kernel=True) + self.norm5 = nn.GroupNorm(32,32) + self.conv9 = ln.ConvTranspose1d(32, latent_size, 3, 2, 1, transform_kernel=True) + if self.global_w: + self.linear1 = ln.Linear(256,128) + self.linear2 = ln.Linear(128,latent_size) + def forward(self,ecog,mask_prior): + x_all_global = [] + x_all_local = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,2:-2] + if mask_prior is not None: + mask = mask*mask_prior_d + x = x[:,:,2:-2] + x = x*mask + x = self.conv3(x) + x = self.conv4(x) + x = x.max(-1)[0].max(-1)[0] + x_common = self.conv5(F.leaky_relu(self.norm(x),0.2)) + if self.global_w: + x_global = x_common.max(-1)[0] + x_global = self.linear1(F.leaky_relu(x_global,0.2)) + x_global = self.linear2(F.leaky_relu(x_global,0.2)) + x_global = F.leaky_relu(x_global,0.2) + x_all_global += [x_global] + if self.temporal_w: + x_local = self.conv6(F.leaky_relu(self.norm2(x_common),0.2)) + x_local = self.conv7(F.leaky_relu(self.norm3(x_local),0.2)) + x_local = self.conv8(F.leaky_relu(self.norm4(x_local),0.2)) + x_local = self.conv9(F.leaky_relu(self.norm5(x_local),0.2)) + x_local = F.leaky_relu(x_local,0.2) + x_all_local += [x_local] + if self.global_w and self.temporal_w: + x_all = (torch.cat(x_all_local,dim=0),torch.cat(x_all_global,dim=0)) + else: + if self.temporal_w: + x_all = torch.cat(x_all_local,dim=0) + else: + x_all = torch.cat(x_all_global,dim=0) + return x_all + +@ECOG_ENCODER.register("ECoGMappingDefault") +class ECoGMapping(nn.Module): + def __init__(self, latent_size,average_w = False,temporal_w=False,global_w=True,attention=None,temporal_samples=None,attentional_style=False,heads=1): + 
super(ECoGMapping, self).__init__() + self.temporal_w = temporal_w + self.global_w = global_w + self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,resample = [2,1,1]) + self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,resample = [2,1,1]) + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + self.conv3 = ECoGMappingBlock(64,128,[3,3,3],residual=True,resample = [2,2,2]) + self.conv4 = ECoGMappingBlock(128,256,[3,3,3],residual=True,resample = [2,2,2]) + self.norm = nn.GroupNorm(32,256) + self.conv5 = ln.Conv1d(256,256,3,1,1) + if self.temporal_w: + self.norm2 = nn.GroupNorm(32,256) + self.conv6 = ln.Conv1d(256,256,3,1,1) + self.norm3 = nn.GroupNorm(32,256) + self.conv7 = ln.Conv1d(256,latent_size,3,1,1) + if self.global_w: + self.linear1 = ln.Linear(256*8,512) + self.linear2 = ln.Linear(512,latent_size,gain=1) + def forward(self,ecog,mask_prior): + x_all_global = [] + x_all_local = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,2:-2] + if mask_prior is not None: + mask = mask*mask_prior_d + x = x[:,:,2:-2] + x = x*mask + x = self.conv3(x) + x = self.conv4(x) + x = x.max(-1)[0].max(-1)[0] + x_common = self.conv5(F.leaky_relu(self.norm(x),0.2)) + if self.global_w: + x_global = x_common.flatten(1) + x_global = self.linear1(F.leaky_relu(x_global,0.2)) + x_global = self.linear2(F.leaky_relu(x_global,0.2)) + x_all_global += [x_global] + if self.temporal_w: + x_local = self.conv6(F.leaky_relu(self.norm2(x_common),0.2)) + x_local = self.conv7(F.leaky_relu(self.norm3(x_local),0.2)) + x_all_local += [x_local] + if self.global_w and self.temporal_w: + x_all = (torch.cat(x_all_local,dim=0),torch.cat(x_all_global,dim=0)) + else: + if self.temporal_w: + x_all = torch.cat(x_all_local,dim=0) + else: + x_all = torch.cat(x_all_global,dim=0) + return x_all + + + + +@ENCODERS.register("EncoderDemod") +class Encoder_Demod(nn.Module): + def __init__(self, startf, maxf, layer_count, latent_size, channels=3,average_w = False,temporal_w=False,residual=False,attention=None,temporal_samples=None,spec_chans=None,attentional_style=False,heads=1): + super(Encoder_Demod, self).__init__() + self.maxf = maxf + self.startf = startf + self.layer_count = layer_count + self.channels = channels + self.latent_size = latent_size + self.average_w = average_w + self.temporal_w = temporal_w + self.attentional_style = attentional_style + mul = 2 + inputs = startf + self.encode_block = nn.ModuleList() + self.attention_block = nn.ModuleList() + resolution = 2 ** (self.layer_count + 1) + for i in range(self.layer_count): + outputs = min(self.maxf, startf * mul) + apply_attention = attention and attention[self.layer_count-i-1] + current_spec_chans = spec_chans // 2**i + current_temporal_samples = temporal_samples // 2**i + last = i==(self.layer_count-1) + fused_scale = resolution >= 128 + resolution //= 2 + block = DemodEncodeBlock(inputs, outputs, latent_size, last,temporal_w=temporal_w,fused_scale=fused_scale,resample=True,temporal_samples=current_temporal_samples,spec_chans=current_spec_chans, + attention=apply_attention,attentional_style=attentional_style,heads=heads,channels=channels) + #print("encode_block%d %s styles out: %d" % ((i + 1), millify(count_parameters(block)), inputs)) + 
self.encode_block.append(block) + inputs = outputs + mul *= 2 + + def encode(self, spec, lod): + if self.temporal_w: + styles = torch.zeros(spec.shape[0], 1, self.latent_size,128) + else: + styles = torch.zeros(spec.shape[0], 1, self.latent_size) + + x = None + for i in range(self.layer_count - lod - 1, self.layer_count): + spec, x, w = self.encode_block[i](spec,x) + if self.temporal_w and i!=0: + w = F.interpolate(w,scale_factor=2**i) + styles[:, 0] += w + if self.average_w: + styles /= (lod+1) + return styles + + def forward(self, x, lod, blend): + if blend == 1: + return self.encode(x, lod) + else: + return self.encode2(x, lod, blend) + + +@ENCODERS.register("EncoderFormant") +class FormantEncoder(nn.Module): + def __init__(self, n_mels=64, n_formants=4, k=30): + super(FormantEncoder, self).__init__() + self.n_mels = n_mels + self.conv1 = ln.Conv1d(n_mels,64,3,1,1) + self.norm1 = nn.GroupNorm(32,64) + self.conv2 = ln.Conv1d(64,128,3,1,1) + self.norm2 = nn.GroupNorm(32,128) + + self.conv_fundementals = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,128) + self.conv_f0 = ln.Conv1d(128,1,1,1,0) + self.conv_amplitudes = ln.Conv1d(128,2,1,1,0) + # self.conv_loudness = ln.Conv1d(128,1,1,1,0) + + self.conv_formants = ln.Conv1d(128,128,3,1,1) + self.norm_formants = nn.GroupNorm(32,128) + self.conv_formants_freqs = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(128,n_formants,1,1,0) + + self.amplifier = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.amplifier,1.0) + + def forward(self,x): + x = x.squeeze(dim=1).permute(0,2,1) #B * f * T + loudness = torch.mean(x*0.5+0.5,dim=1,keepdim=True) + loudness = F.softplus(self.amplifier)*loudness + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x_common = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + + # loudness = F.relu(self.conv_loudness(x_common)) + amplitudes = F.softmax(self.conv_amplitudes(x_common),dim=1) + + x_fundementals = F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2) + # f0 = F.sigmoid(self.conv_f0(x_fundementals)) + # f0 = F.tanh(self.conv_f0(x_fundementals)) * (16/64)*(self.n_mels/64) # 72hz < f0 < 446 hz + f0 = F.sigmoid(self.conv_f0(x_fundementals)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = F.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = F.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = F.sigmoid(self.conv_formants_freqs(x_formants)) + formants_freqs = torch.cumsum(formants_freqs,dim=1) + formants_freqs = formants_freqs + # formants_freqs = formants_freqs + f0 + formants_bandwidth = F.sigmoid(self.conv_formants_bandwidth(x_formants)) + formants_amplitude = F.softmax(self.conv_formants_amplitude(x_formants),dim=1) + + return f0,loudness,amplitudes,formants_freqs,formants_bandwidth,formants_amplitude + @ENCODERS.register("EncoderDefault") class Encoder_old(nn.Module): - def __init__(self, startf, maxf, layer_count, latent_size, channels=3): + def __init__(self, startf, maxf, layer_count, latent_size, channels=3,average_w = False,temporal_w=False,global_w=True,temporal_global_cat = False,residual=False,attention=None,temporal_samples=None,spec_chans=None,attentional_style=False,heads=1): 
super(Encoder_old, self).__init__() self.maxf = maxf self.startf = startf @@ -283,50 +1083,115 @@ def __init__(self, startf, maxf, layer_count, latent_size, channels=3): self.from_rgb: nn.ModuleList[FromRGB] = nn.ModuleList() self.channels = channels self.latent_size = latent_size - + self.average_w = average_w + self.temporal_w = temporal_w + self.global_w = global_w + self.temporal_global_cat = temporal_global_cat + self.attentional_style = attentional_style mul = 2 inputs = startf self.encode_block: nn.ModuleList[EncodeBlock] = nn.ModuleList() - + self.attention_block = nn.ModuleList() resolution = 2 ** (self.layer_count + 1) for i in range(self.layer_count): outputs = min(self.maxf, startf * mul) - self.from_rgb.append(FromRGB(channels, inputs)) - + self.from_rgb.append(FromRGB(channels, inputs,residual=residual)) + apply_attention = attention and attention[self.layer_count-i-1] + non_local = Attention(inputs,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat,attentional_style=attentional_style,decoding=False,latent_size=latent_size,heads=heads) if apply_attention else None + self.attention_block.append(non_local) fused_scale = resolution >= 128 - - block = EncodeBlock(inputs, outputs, latent_size, False, fused_scale=fused_scale) + current_spec_chans = spec_chans // 2**i + current_temporal_samples = temporal_samples // 2**i + islast = i==(self.layer_count-1) + block = EncodeBlock(inputs, outputs, latent_size, False, islast, fused_scale=fused_scale,temporal_w=temporal_w,global_w=global_w,temporal_global_cat=temporal_global_cat,residual=residual,resample=True,temporal_samples=current_temporal_samples,spec_chans=current_spec_chans) resolution //= 2 #print("encode_block%d %s styles out: %d" % ((i + 1), millify(count_parameters(block)), inputs)) + self.encode_block.append(block) inputs = outputs mul *= 2 def encode(self, x, lod): - styles = torch.zeros(x.shape[0], 1, self.latent_size) + if self.temporal_w and self.global_w: + styles = torch.zeros(x.shape[0], 1, self.latent_size,128) + styles_global = torch.zeros(x.shape[0], 1, self.latent_size) + else: + if self.temporal_w: + styles = torch.zeros(x.shape[0], 1, self.latent_size,128) + else: + styles = torch.zeros(x.shape[0], 1, self.latent_size) x = self.from_rgb[self.layer_count - lod - 1](x) x = F.leaky_relu(x, 0.2) for i in range(self.layer_count - lod - 1, self.layer_count): - x, s1, s2 = self.encode_block[i](x) - styles[:, 0] += s1 + s2 - - return styles + if self.attention_block[i]: + x = self.attention_block[i](x) + if self.attentional_style: + if self.temporal_w and self.global_w: + x,s,s_global = x + else: + x,s = x + if self.temporal_w: + s = F.interpolate(s,scale_factor=2**i,mode='linear') + if self.temporal_w and self.global_w: + x, s1, s1_global, s2, s2_global = self.encode_block[i](x) + else: + x, s1, s2 = self.encode_block[i](x) + if self.temporal_w and i!=0: + s1 = F.interpolate(s1,scale_factor=2**i,mode='linear') + s2 = F.interpolate(s2,scale_factor=2**i,mode='linear') + if self.temporal_w and self.global_w: + styles_global[:, 0] += s1_global + s2_global + (s_global if (self.attention_block[i] and self.attentional_style) else 0) + styles[:, 0] += s1 + s2 + (s if (self.attention_block[i] and self.attentional_style) else 0) + if self.temporal_w and self.global_w: + styles_global[:, 0] += s1_global + s2_global + (s_global if (self.attention_block[i] and self.attentional_style) else 0) + if self.average_w: + styles /= (lod+1) + if self.temporal_w and self.global_w: + styles_global/=(lod+1) + + if 
self.temporal_w and self.global_w: + return styles,styles_global + else: + return styles def encode2(self, x, lod, blend): x_orig = x - styles = torch.zeros(x.shape[0], 1, self.latent_size) - + if self.temporal_w and self.global_w: + styles = torch.zeros(x.shape[0], 1, self.latent_size,128) + styles_global = torch.zeros(x.shape[0], 1, self.latent_size) + else: + if self.temporal_w: + styles = torch.zeros(x.shape[0], 1, self.latent_size,128) + else: + styles = torch.zeros(x.shape[0], 1, self.latent_size) x = self.from_rgb[self.layer_count - lod - 1](x) x = F.leaky_relu(x, 0.2) + if self.attention_block[self.layer_count - lod - 1]: + x = self.attention_block[self.layer_count - lod - 1](x) + if self.attentional_style: + if self.temporal_w and self.global_w: + x,s,s_global = x + else: + x,s = x + if self.temporal_w: + s = F.interpolate(s,scale_factor=2**(self.layer_count - lod - 1),mode='linear') + if self.temporal_w and self.global_w: + x, s1, s1_global, s2, s2_global = self.encode_block[self.layer_count - lod - 1](x) + else: + x, s1, s2 = self.encode_block[self.layer_count - lod - 1](x) + if self.temporal_w and (self.layer_count - lod - 1)!=0: + s1 = F.interpolate(s1,scale_factor=2**(self.layer_count - lod - 1),mode='linear') + s2 = F.interpolate(s2,scale_factor=2**(self.layer_count - lod - 1),mode='linear') + styles[:, 0] += s1 * blend + s2 * blend + (s*blend if (self.attention_block[self.layer_count - lod - 1] and self.attentional_style) else 0) + if self.temporal_w and self.global_w: + styles_global[:, 0] += s1_global * blend + s2_global * blend + (s_global*blend if (self.attention_block[self.layer_count - lod - 1] and self.attentional_style) else 0) - x, s1, s2 = self.encode_block[self.layer_count - lod - 1](x) - styles[:, 0] += s1 * blend + s2 * blend x_prev = F.avg_pool2d(x_orig, 2, 2) @@ -336,10 +1201,33 @@ def encode2(self, x, lod, blend): x = torch.lerp(x_prev, x, blend) for i in range(self.layer_count - (lod - 1) - 1, self.layer_count): - x, s1, s2 = self.encode_block[i](x) - styles[:, 0] += s1 + s2 - - return styles + if self.attention_block[i]: + x = self.attention_block[i](x) + if self.attentional_style: + if self.temporal_w and self.global_w: + x,s,s_global = x + else: + x,s = x + if self.temporal_w: + s = F.interpolate(s,scale_factor=2**i,mode='linear') + if self.temporal_w and self.global_w: + x, s1, s1_global, s2, s2_global = self.encode_block[i](x) + else: + x, s1, s2 = self.encode_block[i](x) + if self.temporal_w and i!=0: + s1 = F.interpolate(s1,scale_factor=2**i,mode='linear') + s2 = F.interpolate(s2,scale_factor=2**i,mode='linear') + styles[:, 0] += s1 + s2 + (s if (self.attention_block[i] and self.attentional_style) else 0) + if self.temporal_w and self.global_w: + styles_global[:, 0] += s1_global + s2_global + (s_global if (self.attention_block[i] and self.attentional_style) else 0) + if self.average_w: + styles /= (lod+1) + if self.temporal_w and self.global_w: + styles_global/=(lod+1) + if self.temporal_w and self.global_w: + return styles,styles_global + else: + return styles def forward(self, x, lod, blend): if blend == 1: @@ -674,10 +1562,147 @@ def forward(self, x, lod, blend): else: return self.encode2(x, lod, blend) +@GENERATORS.register("GeneratorDemod") +class Generator_Demod(nn.Module): + def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, channels=3, temporal_samples=128,spec_chans=128,temporal_w=False,init_zeros=False,residual=False,attention=None,attentional_style=False,heads=1): + super(Generator_Demod, self).__init__() + self.maxf = 
maxf + self.startf = startf + self.layer_count = layer_count + + self.channels = channels + self.latent_size = latent_size + self.temporal_w = temporal_w + self.init_zeros = init_zeros + self.attention = attention + self.attentional_style = attentional_style + mul = 2 ** (self.layer_count - 1) + inputs = min(self.maxf, startf * mul) + self.initial_inputs = inputs + self.init_specchans = spec_chans//2**(self.layer_count-1) + self.init_temporalsamples = temporal_samples//2**(self.layer_count-1) + self.layer_to_resolution = [0 for _ in range(layer_count)] + resolution = 2 + + self.style_sizes = [] + + to_rgb = nn.ModuleList() + self.attention_block = nn.ModuleList() + self.decode_block: nn.ModuleList[DemodDecodeBlock] = nn.ModuleList() + for i in range(self.layer_count): + outputs = min(self.maxf, startf * mul) + + has_first_conv = i != 0 + fused_scale = resolution * 2 >= 128 + block = DemodDecodeBlock(inputs, outputs, latent_size, has_first_conv, layer=i,temporal_w=temporal_w,fused_scale=fused_scale,attention=attention and attention[i],attentional_style=attentional_style,heads=heads,channels=channels) + + resolution *= 2 + self.layer_to_resolution[i] = resolution + self.decode_block.append(block) + inputs = outputs + mul //= 2 + + def decode(self, styles, lod, noise): + x = torch.randn([styles.shape[0], self.initial_inputs, self.init_temporalsamples, self.init_specchans]) + spec = None + self.std_each_scale = [] + for i in range(lod + 1): + if self.temporal_w and i!=self.layer_count-1: + w1 = F.interpolate(styles[:, 2 * i + 0],scale_factor=2**-(self.layer_count-i-1),mode='linear') + w2 = F.interpolate(styles[:, 2 * i + 1],scale_factor=2**-(self.layer_count-i-1),mode='linear') + else: + w1 = styles[:, 2 * i + 0] + w2 = styles[:, 2 * i + 1] + spec, x = self.decode_block[i](x, spec, w1, noise) + self.std_each_scale.append(spec.std()) + self.std_each_scale = torch.stack(self.std_each_scale) + self.std_each_scale/=self.std_each_scale.sum() + return spec + + def forward(self, styles, lod, blend, noise): + if blend == 1: + return self.decode(styles, lod, noise) + else: + return self.decode2(styles, lod, blend, noise) + + +@GENERATORS.register("GeneratorFormant") +class FormantSysth(nn.Module): + def __init__(self, n_mels=64, k=30): + super(FormantSysth, self).__init__() + self.n_mels = n_mels + self.k = k + self.timbre = Parameter(torch.Tensor(1,1,n_mels)) + self.silient = -1 + with torch.no_grad(): + nn.init.constant_(self.timbre,1.0) + # nn.init.constant_(self.silient,0.0) + + def formant_mask(self,freq,bandwith,amplitude): + # freq, bandwith, amplitude: B*formants*time + freq_cord = torch.arange(self.n_mels) + time_cord = torch.arange(freq.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + freq = freq.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + bandwith = bandwith.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + amplitude = amplitude.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + masks = amplitude*torch.exp(-(grid_freq-freq)**2/(2*bandwith**2)) #B,time,freqchans, formants + masks = masks.unsqueeze(dim=1) #B,1,time,freqchans + return masks + + def mel_scale(self,hz): + return (torch.log2(hz/440)+31/24)*24*self.n_mels/126 + + def inverse_mel_scale(self,mel): + return 440*2**(mel*126/24-31/24) + + def voicing(self,f0): + #f0: B*1*time + freq_cord = torch.arange(self.n_mels) + time_cord = 
torch.arange(f0.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + f0 = f0.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, 1 + f0 = f0.repeat([1,1,1,self.k]) #B,time,1, self.k + f0 = f0*(torch.arange(self.k)+1).reshape([1,1,1,self.k]) + bandwith = 24.7*(f0*4.37/1000+1) + bandwith_lower = torch.clamp(f0-bandwith/2,min=0.001) + bandwith_upper = f0+bandwith/2 + bandwith = self.mel_scale(bandwith_upper) - self.mel_scale(bandwith_lower) + f0 = self.mel_scale(f0) + # hamonics = torch.exp(-(grid_freq-f0)**2/(2*bandwith**2)) #gaussian + hamonics = (1-((grid_freq-f0)/(3*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(3*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = torch.cos(np.pi*torch.abs(grid_freq-f0)/(4*bandwith))**2*(-torch.sign(torch.abs(grid_freq-f0)/(4*bandwith)-0.5)*0.5+0.5) #hanning + hamonics = (hamonics.sum(dim=-1)*self.timbre).unsqueeze(dim=1) # B,1,T,F + return hamonics + + def unvoicing(self,f0): + return torch.ones([f0.shape[0],1,f0.shape[2],self.n_mels]) + + def forward(self,f0,loudness,amplitudes,freq_formants,bandwidth_formants,amplitude_formants): + # f0: B*1*T, amplitudes: B*2(voicing,unvoicing)*T, freq_formants,bandwidth_formants,amplitude_formants: B*formants*T + amplitudes = amplitudes.unsqueeze(dim=-1) + loudness = loudness.unsqueeze(dim=-1) + f0_hz = self.inverse_mel_scale(f0) + self.hamonics = self.voicing(f0_hz) + self.noise = self.unvoicing(f0_hz) + freq_formants = freq_formants*self.n_mels + bandwidth_formants = bandwidth_formants*self.n_mels + # excitation = amplitudes[:,0:1]*hamonics + # excitation = loudness*(amplitudes[:,0:1]*hamonics) + self.excitation = loudness*(amplitudes[:,0:1]*self.hamonics + amplitudes[:,-1:]*self.noise) + self.mask = self.formant_mask(freq_formants,bandwidth_formants,amplitude_formants) + self.mask_sum = self.mask.sum(dim=-1) + speech = self.excitation*self.mask_sum + self.silient*torch.ones(self.mask_sum.shape) + return speech + @GENERATORS.register("GeneratorDefault") class Generator(nn.Module): - def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, channels=3): + def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, channels=3, temporal_samples=128,spec_chans=128,temporal_w=False,global_w=True,temporal_global_cat = False,init_zeros=False,residual=False,attention=None,attentional_style=False,heads=1): super(Generator, self).__init__() self.maxf = maxf self.startf = startf @@ -685,12 +1710,22 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, channels self.channels = channels self.latent_size = latent_size - + self.temporal_w = temporal_w + self.global_w=global_w + self.temporal_global_cat = temporal_global_cat and (temporal_w and global_w) + self.init_zeros = init_zeros + self.attention = attention + self.attentional_style = attentional_style mul = 2 ** (self.layer_count - 1) inputs = min(self.maxf, startf * mul) - self.const = Parameter(torch.Tensor(1, inputs, 4, 4)) - init.ones_(self.const) + init_specchans = spec_chans//2**(self.layer_count-1) + init_temporalsamples = temporal_samples//2**(self.layer_count-1) + self.const = Parameter(torch.Tensor(1, inputs, init_temporalsamples, init_specchans)) + if init_zeros: + init.zeros_(self.const) + else: + init.ones_(self.const) self.layer_to_resolution = [0 for _ in range(layer_count)] resolution = 2 @@ -698,7 +1733,7 @@ def __init__(self, startf=32, maxf=256, 
layer_count=3, latent_size=128, channels self.style_sizes = [] to_rgb = nn.ModuleList() - + self.attention_block = nn.ModuleList() self.decode_block: nn.ModuleList[DecodeBlock] = nn.ModuleList() for i in range(self.layer_count): outputs = min(self.maxf, startf * mul) @@ -706,46 +1741,113 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, channels has_first_conv = i != 0 fused_scale = resolution * 2 >= 128 - block = DecodeBlock(inputs, outputs, latent_size, has_first_conv, fused_scale=fused_scale, layer=i) + block = DecodeBlock(inputs, outputs, latent_size, has_first_conv, fused_scale=fused_scale, layer=i,temporal_w=temporal_w, global_w=global_w, temporal_global_cat=temporal_global_cat, residual=residual,resample=True) resolution *= 2 self.layer_to_resolution[i] = resolution self.style_sizes += [2 * (inputs if has_first_conv else outputs), 2 * outputs] - to_rgb.append(ToRGB(outputs, channels)) + to_rgb.append(ToRGB(outputs, channels,residual=residual)) #print("decode_block%d %s styles in: %dl out resolution: %d" % ( # (i + 1), millify(count_parameters(block)), outputs, resolution)) + apply_attention = attention and attention[i] + non_local = Attention(outputs,temporal_w=temporal_w, global_w=global_w, temporal_global_cat=temporal_global_cat, attentional_style=attentional_style,decoding=True,latent_size=latent_size,heads=heads) if apply_attention else None self.decode_block.append(block) + self.attention_block.append(non_local) inputs = outputs mul //= 2 self.to_rgb = to_rgb def decode(self, styles, lod, noise): + if self.temporal_w and self.global_w: + styles,styles_global = styles x = self.const - + self.std_each_scale = [] for i in range(lod + 1): - x = self.decode_block[i](x, styles[:, 2 * i + 0], styles[:, 2 * i + 1], noise) + if self.temporal_w and i!=self.layer_count-1: + w1 = F.interpolate(styles[:, 2 * i + 0],scale_factor=2**-(self.layer_count-i-1),mode='linear') + w2 = F.interpolate(styles[:, 2 * i + 1],scale_factor=2**-(self.layer_count-i-1),mode='linear') + # if self.temporal_w and self.global_w: + # w1_global = F.interpolate(styles_global[:, 2 * i + 0],scale_factor=2**-(self.layer_count-i-1),mode='linear') + # w2_global = F.interpolate(styles_global[:, 2 * i + 1],scale_factor=2**-(self.layer_count-i-1),mode='linear') + else: + w1 = styles[:, 2 * i + 0] + w2 = styles[:, 2 * i + 1] + if self.temporal_w and self.global_w: + w1_global = styles_global[:, 2 * i + 0] + w2_global = styles_global[:, 2 * i + 1] + x = self.decode_block[i](x, w1, w2, noise, w1_global, w2_global) + else: + x = self.decode_block[i](x, w1, w2, noise) + if self.attention_block[i]: + if self.temporal_w and self.global_w: + x = self.attention_block[i](x,w2, w2_global) if self.attentional_style else self.attention_block[i](x) + else: + x = self.attention_block[i](x,w2) if self.attentional_style else self.attention_block[i](x) + self.std_each_scale.append(x.std()) + self.std_each_scale = torch.stack(self.std_each_scale) + self.std_each_scale/=self.std_each_scale.sum() x = self.to_rgb[lod](x) return x def decode2(self, styles, lod, blend, noise): + if self.temporal_w and self.global_w: + styles,styles_global = styles x = self.const for i in range(lod): - x = self.decode_block[i](x, styles[:, 2 * i + 0], styles[:, 2 * i + 1], noise) - + if self.temporal_w and i!=self.layer_count-1: + w1 = F.interpolate(styles[:, 2 * i + 0],scale_factor=2**-(self.layer_count-i-1),mode='linear') + w2 = F.interpolate(styles[:, 2 * i + 1],scale_factor=2**-(self.layer_count-i-1),mode='linear') + # if self.temporal_w 
and self.global_w: + # w1_global = F.interpolate(styles_global[:, 2 * i + 0],scale_factor=2**-(self.layer_count-i-1),mode='linear') + # w2_global = F.interpolate(styles_global[:, 2 * i + 1],scale_factor=2**-(self.layer_count-i-1),mode='linear') + else: + w1 = styles[:, 2 * i + 0] + w2 = styles[:, 2 * i + 1] + if self.temporal_w and self.global_w: + w1_global = styles_global[:, 2 * i + 0] + w2_global = styles_global[:, 2 * i + 1] + x = self.decode_block[i](x, w1, w2, noise, w1_global, w2_global) + else: + x = self.decode_block[i](x, w1, w2, noise) + if self.attention_block[i]: + if self.temporal_w and self.global_w: + x = self.attention_block[i](x,w2,w2_global) if self.attentional_style else self.attention_block[i](x) + else: + x = self.attention_block[i](x,w2) if self.attentional_style else self.attention_block[i](x) x_prev = self.to_rgb[lod - 1](x) - x = self.decode_block[lod](x, styles[:, 2 * lod + 0], styles[:, 2 * lod + 1], noise) + if self.temporal_w and lod!=self.layer_count-1: + w1 = F.interpolate(styles[:, 2 * lod + 0],scale_factor=2**-(self.layer_count-lod-1),mode='linear') + w2 = F.interpolate(styles[:, 2 * lod + 1],scale_factor=2**-(self.layer_count-lod-1),mode='linear') + if self.temporal_w and self.global_w: + w1_global = F.interpolate(styles_global[:, 2 * lod + 0],scale_factor=2**-(self.layer_count-lod-1),mode='linear') + w2_global = F.interpolate(styles_global[:, 2 * lod + 1],scale_factor=2**-(self.layer_count-lod-1),mode='linear') + else: + w1 = styles[:, 2 * lod + 0] + w2 = styles[:, 2 * lod + 1] + if self.temporal_w and self.global_w: + w1_global = styles_global[:, 2 * lod + 0] + w2_global = styles_global[:, 2 * lod + 1] + if self.temporal_w and self.global_w: + x = self.decode_block[lod](x, w1, w2, noise, w1_global, w2_global) + else: + x = self.decode_block[lod](x, w1, w2, noise) + if self.attention_block[lod]: + if self.temporal_w and self.global_w: + x = self.attention_block[lod](x,w2,w2_global) if self.attentional_style else self.attention_block[lod](x) + else: + x = self.attention_block[lod](x,w2) if self.attentional_style else self.attention_block[lod](x) x = self.to_rgb[lod](x) needed_resolution = self.layer_to_resolution[lod] - x_prev = F.interpolate(x_prev, size=needed_resolution) + x_prev = F.interpolate(x_prev, scale_factor = 2.0) x = torch.lerp(x_prev, x, blend) return x @@ -850,12 +1952,18 @@ def forward(self, x): class MappingBlock(nn.Module): - def __init__(self, inputs, output, lrmul): + def __init__(self, inputs, output, stride =1,lrmul=0.1,temporal_w=False,transpose=False,transform_kernel=False,use_sn=False): super(MappingBlock, self).__init__() - self.fc = ln.Linear(inputs, output, lrmul=lrmul) + if temporal_w: + if transpose: + self.map = sn(ln.ConvTranspose1d(inputs, output, 3,stride,1,0,lrmul=lrmul,transform_kernel=transform_kernel),use_sn=use_sn) + else: + self.map = sn(ln.Conv1d(inputs, output, 3,stride,1,lrmul=lrmul,transform_kernel=transform_kernel),use_sn=use_sn) + else: + self.map = sn(ln.Linear(inputs, output, lrmul=lrmul),use_sn=use_sn) def forward(self, x): - x = F.leaky_relu(self.fc(x), 0.2) + x = F.leaky_relu(self.map(x), 0.2) return x @@ -884,24 +1992,64 @@ def forward(self, z): @MAPPINGS.register("MappingToLatent") class VAEMappingToLatent_old(nn.Module): - def __init__(self, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256): + def __init__(self, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256,temporal_w=False, global_w=True): super(VAEMappingToLatent_old, self).__init__() + self.temporal_w = 
temporal_w + self.global_w = global_w + inputs = 2* latent_size if (temporal_w and global_w) else latent_size + self.mapping_layers = mapping_layers + self.map_blocks: nn.ModuleList[MappingBlock] = nn.ModuleList() + for i in range(mapping_layers): + if not temporal_w: + outputs = 2 * dlatent_size if i == mapping_layers - 1 else mapping_fmaps + else: + outputs = mapping_fmaps + block = MappingBlock(inputs, outputs, stride = 2 if i!=0 else 1,lrmul=0.1,temporal_w=temporal_w,transform_kernel=True if i!=0 else False) + inputs = outputs + self.map_blocks.append(block) + #print("dense %d %s" % ((i + 1), millify(count_parameters(block)))) + if temporal_w: + self.Linear = sn(ln.Linear(inputs*8,2 * dlatent_size,lrmul=0.1)) + def forward(self, x, x_global=None): + # if x.dim()==3: + # x = torch.mean(x,dim=2) + if (self.temporal_w and self.global_w): + x = torch.cat((x,x_global.unsqueeze(2).repeat(1,1,x.shape[2])),dim=1) + for i in range(self.mapping_layers): + x = self.map_blocks[i](x) + if self.temporal_w: + x = x.view(x.shape[0],x.shape[1]*x.shape[2]) + x = self.Linear(x) + return x.view(x.shape[0], 2, x.shape[1] // 2) + +@MAPPINGS.register("MappingToWord") +class MappingToWord(nn.Module): + def __init__(self, mapping_layers=5, latent_size=256, uniq_words=256, mapping_fmaps=256,temporal_w=False): + super(MappingToWord, self).__init__() + self.temporal_w = temporal_w inputs = latent_size self.mapping_layers = mapping_layers self.map_blocks: nn.ModuleList[MappingBlock] = nn.ModuleList() for i in range(mapping_layers): - outputs = 2 * dlatent_size if i == mapping_layers - 1 else mapping_fmaps - block = ln.Linear(inputs, outputs, lrmul=0.1) + if not temporal_w: + outputs = uniq_words if i == mapping_layers - 1 else mapping_fmaps + else: + outputs = mapping_fmaps + block = MappingBlock(inputs, outputs , stride = 2 if i!=0 else 1,lrmul=0.1,temporal_w=temporal_w,transform_kernel=True if i!=0 else False) inputs = outputs self.map_blocks.append(block) #print("dense %d %s" % ((i + 1), millify(count_parameters(block)))) - + if temporal_w: + self.Linear = sn(ln.Linear(inputs*8,uniq_words,lrmul=0.1)) def forward(self, x): + if x.dim()==3: + x = torch.mean(x,dim=2) for i in range(self.mapping_layers): x = self.map_blocks[i](x) - - return x.view(x.shape[0], 2, x.shape[2] // 2) - + if self.temporal_w: + x = x.view(x.shape[0],x.shape[1]*x.shape[2]) + x = self.Linear(x) + return x @MAPPINGS.register("MappingToLatentNoStyle") class VAEMappingToLatentNoStyle(nn.Module): @@ -929,26 +2077,55 @@ def forward(self, x): @MAPPINGS.register("MappingFromLatent") class VAEMappingFromLatent(nn.Module): - def __init__(self, num_layers, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256): + def __init__(self, num_layers, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256,temporal_w=False,global_w = True): super(VAEMappingFromLatent, self).__init__() - inputs = dlatent_size self.mapping_layers = mapping_layers self.num_layers = num_layers + self.temporal_w = temporal_w + self.global_w = global_w + self.latent_size = latent_size self.map_blocks: nn.ModuleList[MappingBlock] = nn.ModuleList() + if temporal_w and global_w: + self.map_blocks_global: nn.ModuleList[MappingBlock] = nn.ModuleList() + if temporal_w: + self.Linear = sn(ln.Linear(dlatent_size,8*(latent_size//8)),use_sn=PARTIAL_SN) + inputs = latent_size//8 + else: + inputs = dlatent_size for i in range(mapping_layers): outputs = latent_size if i == mapping_layers - 1 else mapping_fmaps - block = MappingBlock(inputs, outputs, lrmul=0.1) + 
block = MappingBlock(inputs, outputs, stride = i%2+1, lrmul=0.1,temporal_w=temporal_w,transform_kernel=True if i%2==1 else False, transpose=True,use_sn=PARTIAL_SN) inputs = outputs self.map_blocks.append(block) #print("dense %d %s" % ((i + 1), millify(count_parameters(block)))) - - def forward(self, x): + + if temporal_w and global_w: + inputs = dlatent_size + for i in range(mapping_layers): + outputs = latent_size if i == mapping_layers - 1 else mapping_fmaps + block_global = MappingBlock(inputs, outputs, stride = i%2+1, lrmul=0.1,temporal_w=False,transform_kernel=True if i%2==1 else False, transpose=True,use_sn=PARTIAL_SN) + inputs = outputs + self.map_blocks_global.append(block_global) + + def forward(self, x,x_global=None): x = pixel_norm(x) - + if self.temporal_w: + x = self.Linear(x) + x = F.leaky_relu(x,0.2) + x = x.view(x.shape[0],self.latent_size//8,8) for i in range(self.mapping_layers): x = self.map_blocks[i](x) - - return x.view(x.shape[0], 1, x.shape[1]).repeat(1, self.num_layers, 1) + + if self.temporal_w and self.global_w: + x_global = pixel_norm(x_global) + for i in range(self.mapping_layers): + x_global = self.map_blocks_global[i](x_global) + return x.view(x.shape[0], 1, x.shape[1],x.shape[2]).repeat(1, self.num_layers, 1,1), x_global.view(x_global.shape[0], 1, x_global.shape[1]).repeat(1, self.num_layers, 1) + else: + if self.temporal_w: + return x.view(x.shape[0], 1, x.shape[1],x.shape[2]).repeat(1, self.num_layers, 1,1) + else: + return x.view(x.shape[0], 1, x.shape[1]).repeat(1, self.num_layers, 1) @ENCODERS.register("EncoderFC") @@ -1015,3 +2192,5 @@ def decode(self, x, lod, blend_factor, noise): def forward(self, x, lod, blend_factor, noise): return self.decode(x, lod, blend_factor, noise) + + diff --git a/net_formant.py b/net_formant.py new file mode 100644 index 00000000..b8fed7c6 --- /dev/null +++ b/net_formant.py @@ -0,0 +1,1597 @@ +import os +import pdb +from random import triangular +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F +from torch.nn import Parameter as P +from torch.nn import init +from torch.nn.parameter import Parameter +import numpy as np +import lreq as ln +import math +from registry import * +#from transformer_models.position_encoding import build_position_encoding +#from transformer_models.transformer import Transformer as TransformerTS +# +# from transformer_models.transformer_nonlocal import Transformer as TransformerNL + +def db(x,noise = -80, slope =35, powerdb=True): + if powerdb: + return ((2*torchaudio.transforms.AmplitudeToDB()(x)).clamp(min=noise)-slope-noise)/slope + else: + return ((torchaudio.transforms.AmplitudeToDB()(x)).clamp(min=noise)-slope-noise)/slope + +# def amplitude(x,noise=-80,slope=35): +# return 10**((x*slope+noise+slope)/20.) 
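Aside (not part of the patch): a minimal sketch of the dB normalization that the db() helper above applies, assuming torch and torchaudio with the default power-scaled AmplitudeToDB; the tensor values below are made up for illustration. Power values are converted to dB, floored at the noise level, and rescaled so that the floor lands at -1:

import torch
import torchaudio

power = torch.tensor([1e-8, 1e-4, 1.0])  # hypothetical power-spectrogram values
x_db = (2 * torchaudio.transforms.AmplitudeToDB()(power)).clamp(min=-80)  # 20*log10(power), floored at -80 dB
x_norm = (x_db - 35 - (-80)) / 35  # same scaling as db(power, noise=-80, slope=35, powerdb=True)
# x_db   -> tensor([-80., -80.,   0.])      values at or below the floor collapse to the noise level
# x_norm -> tensor([-1.0000, -1.0000, 1.2857])   the noise floor maps to -1

The amplitude() and to_db() functions defined next appear to use the same convention, mapping between normalized spectrogram values and an explicit [noise_db, max_db] decibel range.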
+ +def amplitude(x,noise_db=-60,max_db=35,trim_noise=False): + if trim_noise: + x_db = (x+1)/2*(max_db-noise_db)+noise_db + if type(x) is np.ndarray: + return 10**(x_db/10)*(np.sign(x_db-noise_db)*0.5+0.5) + else: + return 10**(x_db/10)*((x_db-noise_db).sign()*0.5+0.5) + else: + return 10**(((x+1)/2*(max_db-noise_db)+noise_db)/10) + +def to_db(x,noise_db=-60,max_db=35): + return (torchaudio.transforms.AmplitudeToDB()(x)-noise_db)/(max_db-noise_db)*2-1 + +def wave2spec(wave,n_fft=256,wave_fr=16000,spec_fr=125,noise_db=-60,max_db=22.5,to_db=True,power=2): +# def wave2spec(wave,n_fft=256,wave_fr=16000,spec_fr=125,noise_db=-50,max_db=22.5,to_db=True): + if to_db: + return (torchaudio.transforms.AmplitudeToDB()(torchaudio.transforms.Spectrogram(n_fft*2-1,win_length=n_fft*2-1,hop_length=int(wave_fr/spec_fr),power=power)(wave)).clamp(min=noise_db,max=max_db).transpose(-2,-1)-noise_db)/(max_db-noise_db)*2-1 + else: + return torchaudio.transforms.Spectrogram(n_fft*2-1,win_length=n_fft*2-1,hop_length=int(wave_fr/spec_fr),power=power)(wave).transpose(-2,-1) + + +# def mel_scale(n_mels,hz,min_octave=-31.,max_octave=95.,pt=True): +# def mel_scale(n_mels,hz,min_octave=-58.,max_octave=100.,pt=True): +def mel_scale(n_mels,hz,min_octave=-31.,max_octave=102.,pt=True): + #take absolute hz, return abs mel + # return (torch.log2(hz/440)+31/24)*24*n_mels/126 + if pt: + return (torch.log2(hz/440.)-min_octave/24.)*24*n_mels/(max_octave-min_octave) + else: + return (np.log2(hz/440.)-min_octave/24.)*24*n_mels/(max_octave-min_octave) + +# def inverse_mel_scale(mel,min_octave=-31.,max_octave=95.): +# def inverse_mel_scale(mel,min_octave=-58.,max_octave=100.): +def inverse_mel_scale(mel,min_octave=-31.,max_octave=102.): + #take normalized mel, return absolute hz + # return 440*2**(mel*126/24-31/24) + return 440*2**(mel*(max_octave-min_octave)/24.+min_octave/24.) + +# def mel_scale(n_mels,hz,f_min=160.,f_max=8000.,pt=True): +# #take absolute hz, return abs mel +# # return (torch.log2(hz/440)+31/24)*24*n_mels/126 +# m_min = 2595.0 * np.log10(1.0 + (f_min / 700.0)) +# m_max = 2595.0 * np.log10(1.0 + (f_max / 700.0)) +# m_min_ = m_min + (m_max-m_min)/(n_mels+1) +# m_max_ = m_max +# if pt: +# return (2595.0 * torch.log10(1.0 + (hz / 700.0))-m_min_)/(m_max_-m_min_)*n_mels +# else: +# return (2595.0 * np.log10(1.0 + (hz / 700.0))-m_min_)/(m_max_-m_min_)*n_mels + +# def inverse_mel_scale(mel,f_min=160.,f_max=8000.,n_mels=64): +# #take normalized mel, return absolute hz +# # return 440*2**(mel*126/24-31/24) +# m_min = 2595.0 * np.log10(1.0 + (f_min / 700.0)) +# m_max = 2595.0 * np.log10(1.0 + (f_max / 700.0)) +# m_min_ = m_min + (m_max-m_min)/(n_mels+1) +# m_max_ = m_max +# return 700.0 * (10**((mel*(m_max_-m_min_) + m_min_)/ 2595.0) - 1.0) + +def ind2hz(ind,n_fft,max_freq=8000.): + #input abs ind, output abs hz + return ind/(1.0*n_fft)*max_freq + +def hz2ind(hz,n_fft,max_freq=8000.): + # input abs hz, output abs ind + return hz/(1.0*max_freq)*n_fft + +def bandwidth_mel(freqs_hz,bandwidth_hz,n_mels): + # input hz bandwidth, output abs bandwidth on mel + bandwidth_upper = freqs_hz+bandwidth_hz/2. 
+ bandwidth_lower = torch.clamp(freqs_hz-bandwidth_hz/2.,min=1) + bandwidth = mel_scale(n_mels,bandwidth_upper) - mel_scale(n_mels,bandwidth_lower) + return bandwidth + +def torch_P2R(radii, angles): + return radii * torch.cos(angles),radii * torch.sin(angles) +def inverse_spec_to_audio(spec,n_fft = 511,win_length = 511,hop_length = 128,power_synth=True): + ''' + generate random phase, then use istft to inverse spec to audio + ''' + window = torch.hann_window(win_length) + angles = torch.randn_like(spec).uniform_(0, np.pi*2)#torch.zeros_like(spec)#torch.randn_like(spec).uniform_(0, np.pi*2) + spec = spec**0.5 if power_synth else spec + spec_complex = torch.stack(torch_P2R(spec, angles),dim=-1) #real and image in same dim + return torchaudio.functional.istft(spec_complex, n_fft=n_fft, window=window, center=True, win_length=win_length, hop_length=hop_length) + +@GENERATORS.register("GeneratorFormant") +class FormantSysth(nn.Module): + def __init__(self, n_mels=64, k=100, wavebased=False,n_fft=256,noise_db=-50,max_db=22.5,dbbased=False,add_bgnoise=True,log10=False,noise_from_data=False,return_wave=False,power_synth=False): + super(FormantSysth, self).__init__() + self.wave_fr = 16e3 + self.spec_fr = 125 + self.n_fft = n_fft + self.noise_db=noise_db + self.max_db = max_db + self.n_mels = n_mels + self.k = k + self.dbbased=dbbased + self.log10 = log10 + self.add_bgnoise = add_bgnoise + self.wavebased=wavebased + self.noise_from_data = noise_from_data + self.linear_scale = wavebased + self.return_wave = return_wave + self.power_synth = power_synth + self.timbre = Parameter(torch.Tensor(1,1,n_mels)) + self.timbre_mapping = nn.Sequential( + ln.Conv1d(1,128,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,2,1), + # nn.Sigmoid(), + ) + self.bgnoise_mapping = nn.Sequential( + ln.Conv2d(2,2,[1,5],padding=[0,2],gain=1,bias=False), + # nn.Sigmoid(), + ) + self.noise_mapping = nn.Sequential( + ln.Conv2d(2,2,[1,5],padding=[0,2],gain=1,bias=False), + # nn.Sigmoid(), + ) + + self.bgnoise_mapping2 = nn.Sequential( + ln.Conv1d(1,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,gain=1,bias=False), + # nn.Sigmoid(), + ) + self.noise_mapping2 = nn.Sequential( + ln.Conv1d(1,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,1,gain=1,bias=False), + # nn.Sigmoid(), + ) + self.prior_exp = np.array([0.4963,0.0745,1.9018]) + self.timbre_parameter = Parameter(torch.Tensor(2)) + self.wave_noise_amplifier = Parameter(torch.Tensor(1)) + self.wave_hamon_amplifier = Parameter(torch.Tensor(1)) + + if noise_from_data: + self.bgnoise_amp = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.bgnoise_amp,1) + else: + self.bgnoise_dist = Parameter(torch.Tensor(1,1,1,self.n_fft if self.wavebased else self.n_mels)) + with torch.no_grad(): + nn.init.constant_(self.bgnoise_dist,1.0) + # self.silient = Parameter(torch.Tensor(1,1,n_mels)) + self.silient = -1 + with torch.no_grad(): + nn.init.constant_(self.timbre,1.0) + nn.init.constant_(self.timbre_parameter[0],7) + nn.init.constant_(self.timbre_parameter[1],0.004) + nn.init.constant_(self.wave_noise_amplifier,1) + nn.init.constant_(self.wave_hamon_amplifier,4.) 
+ + # nn.init.constant_(self.silient,-1.0) + +# def formant_mask(self,freq,bandwith,amplitude): +# # freq, bandwith, amplitude: B*formants*time +# freq_cord = torch.arange(self.n_mels) +# time_cord = torch.arange(freq.shape[2]) +# grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) +# grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 +# grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 +# freq = freq.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# bandwith = bandwith.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# amplitude = amplitude.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# # masks = amplitude*torch.exp(-0.693*(grid_freq-freq)**2/(2*(bandwith+0.001)**2)) #B,time,freqchans, formants +# masks = amplitude*torch.exp(-(grid_freq-freq)**2/(2*(bandwith/np.sqrt(2*np.log(2))+0.001)**2)) #B,time,freqchans, formants +# masks = masks.unsqueeze(dim=1) #B,1,time,freqchans, formants +# return masks + + def formant_mask(self,freq_hz,bandwith_hz,amplitude,linear=False, triangle_mask = False,duomask=True, n_formant_noise=1,f0_hz=None): + # freq, bandwith, amplitude: B*formants*time + freq_cord = torch.arange(self.n_fft if linear else self.n_mels) + time_cord = torch.arange(freq_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq_hz = ind2hz(grid_freq,self.n_fft,self.wave_fr/2) if linear else inverse_mel_scale(grid_freq/(self.n_mels*1.0)) + freq_hz = freq_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + bandwith_hz = bandwith_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + amplitude = amplitude.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + if self.power_synth: + amplitude = amplitude + alpha = (2*np.sqrt(2*np.log(np.sqrt(2)))) + if self.return_wave: + t = torch.arange(int(f0_hz.shape[2]/self.spec_fr*self.wave_fr))/(1.0*self.wave_fr) #in second + t = t.unsqueeze(dim=0).unsqueeze(dim=0) #1, 1, time + k = (torch.arange(self.k)+1).reshape([1,self.k,1]) + # f0_hz_interp = F.interpolate(f0_hz,t.shape[-1],mode='linear',align_corners=False) #Bx1xT + # bandwith_hz_interp = F.interpolate(bandwith_hz.permute(0,2,3,1),[bandwith_hz.shape[-1],t.shape[-1]],mode='bilinear',align_corners=False).permute(0,3,1,2) #Bx1xT + # freq_hz_interp = F.interpolate(freq_hz.permute(0,2,3,1),[freq_hz.shape[-1],t.shape[-1]],mode='bilinear',align_corners=False).permute(0,3,1,2) #Bx1xT + k_f0 = k*f0_hz #BxkxT + k_f0 = k_f0.permute([0,2,1]).unsqueeze(-1) #BxTxkx1 + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) if self.wavebased else amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) #B,time,freqchans, formants + # amplitude_interp = F.interpolate(amplitude.permute(0,2,3,1),[amplitude.shape[-1],t.shape[-1]],mode='bilinear',align_corners=False).permute(0,3,1,2) #Bx1xT + hamonic_dist = (amplitude*torch.exp(-((k_f0-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))).sqrt().sum(-1).permute([0,2,1]) #BxkxT + # hamonic_dist = (amplitude*torch.exp(-((k_f0-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))).sum(-1).permute([0,2,1]) #BxkxT + hamonic_dist = F.interpolate(hamonic_dist,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr),mode = 'linear',align_corners=False) + # if self.wavebased: + if triangle_mask: + if duomask: + # masks_hamon = 
amplitude[...,:-n_formant_noise]*torch.exp(-(0.693*(grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]+0.01)**2)) + masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-((grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]/alpha+0.01)**2)) + bw = bandwith_hz[...,-n_formant_noise:] + masks_noise = F.relu(amplitude[...,-n_formant_noise:] * (1 - (1-1/np.sqrt(2))*2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())) + # masks_noise = amplitude[...,-n_formant_noise:] * (1 - 2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz[...,-n_formant_noise:])/(bw+0.01)-0.5)*0.5+0.5) + masks = torch.cat([masks_hamon,masks_noise],dim=-1) + else: + masks = F.relu(amplitude * (1 - (1-1/np.sqrt(2))*2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())) + else: + # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) + if self.power_synth: + masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + else: + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + masks = amplitude*(torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))+1E-6).sqrt() + # else: + # if triangle_mask: + # if duomask: + # # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-(0.693*(grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]+0.01)**2)) + # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-((grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]/alpha+0.01)**2)) + # bw = bandwith_hz[...,-n_formant_noise:] + # masks_noise = F.relu(amplitude[...,-n_formant_noise:] * (1 - (1-1/np.sqrt(2))*2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())) + # # masks_noise = amplitude[...,-n_formant_noise:] * (1 - 2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz[...,-n_formant_noise:])/(bw+0.01)-0.5)*0.5+0.5) + # masks = torch.cat([masks_hamon,masks_noise],dim=-1) + # else: + # masks = F.relu(amplitude * (1 - (1-1/np.sqrt(2))*2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())) + # # masks = amplitude * (1 - 2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz)/(bandwith_hz+0.01)-0.5)*0.5+0.5) + # else: + # # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/(2*np.sqrt(2*np.log(2)))+0.01)**2)) #B,time,freqchans, formants + # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) #B,time,freqchans, formants + masks = masks.unsqueeze(dim=1) #B,1,time,freqchans, formants + if self.return_wave: + return masks, hamonic_dist#B,1,time,freqchans + else: + return masks + + def voicing_wavebased(self,f0_hz): + #f0: B*1*time, hz + t = torch.arange(int(f0_hz.shape[2]/self.spec_fr*self.wave_fr))/(1.0*self.wave_fr) #in second + t = t.unsqueeze(dim=0).unsqueeze(dim=0) #1, 1, time + k = (torch.arange(self.k)+1).reshape([1,self.k,1]) + f0_hz_interp = F.interpolate(f0_hz,t.shape[-1],mode='linear',align_corners=False) + k_f0 = k*f0_hz_interp + k_f0_sum = 2*np.pi*torch.cumsum(k_f0,-1)/(1.0*self.wave_fr) + wave_k = np.sqrt(2)*torch.sin(k_f0_sum) * 
(-torch.sign(k_f0-7800)*0.5+0.5) + # wave = 0.12*torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-6000)*0.5+0.5) + # wave = 0.09*torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-self.wave_fr/2)*0.5+0.5) + # wave = 0.09*torch.sigmoid(self.wave_hamon_amplifier) * torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-self.wave_fr/2)*0.5+0.5) + wave = wave_k.sum(dim=1,keepdim=True) + # wave = F.softplus(self.wave_hamon_amplifier) * wave.sum(dim=1,keepdim=True) + spec = wave2spec(wave,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=self.dbbased,power=2. if self.power_synth else 1.) + if self.return_wave: + return spec,wave_k + else: + return spec + + def unvoicing_wavebased(self,f0_hz,bg=False,mapping=True): + # return torch.ones([1,1,f0_hz.shape[2],512]) + # noise = 0.3*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.03*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + if bg: + noise = torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + if mapping: + noise = self.bgnoise_mapping2(noise) + else: + noise = np.sqrt(3.)*(2*torch.rand([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)])-1) + if mapping: + noise = self.noise_mapping2(noise) + # noise = 0.3 * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.3 * F.softplus(self.wave_noise_amplifier) * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + return wave2spec(noise,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=self.dbbased,power=2. if self.power_synth else 1.) + + # def unvoicing_wavebased(self,f0_hz): + # # return torch.ones([1,1,f0_hz.shape[2],512]) + # # noise = 0.3*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.1*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # # noise = 0.3 * torch.sigmoid(self.wave_noise_amplifier) * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # return wave2spec(noise,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=False) + + def voicing_linear(self,f0_hz,bandwith=2.5): + #f0: B*1*time, hz + freq_cord = torch.arange(self.n_fft) + time_cord = torch.arange(f0_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + f0_hz = f0_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, 1 + f0_hz = f0_hz.repeat([1,1,1,self.k]) #B,time,1, self.k + f0_hz = f0_hz*(torch.arange(self.k)+1).reshape([1,1,1,self.k]) + # bandwith=4 + # bandwith_lower = torch.clamp(f0-bandwith/2,min=1) + # bandwith_upper = f0+bandwith/2 + # bandwith = mel_scale(self.n_mels,bandwith_upper) - mel_scale(self.n_mels,bandwith_lower) + f0 = hz2ind(f0_hz,self.n_fft) + + # hamonics = torch.exp(-(grid_freq-f0)**2/(2*sigma**2)) #gaussian + # hamonics = (1-((grid_freq_hz-f0_hz)/(2*bandwith_hz/2))**2)*(-torch.sign(torch.abs(grid_freq_hz-f0_hz)/(2*bandwith_hz)-0.5)*0.5+0.5) #welch + freq_cord_reshape = freq_cord.reshape([1,1,1,self.n_fft]) + hamonics = (1 - 2/bandwith*(grid_freq-f0).abs())*(-torch.sign(torch.abs(grid_freq-f0)/(bandwith)-0.5)*0.5+0.5) #triangular + # hamonics = (1-((grid_freq-f0)/(bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(2.5*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(2.5*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = 
(1-((grid_freq-f0)/(3*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(3*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = torch.cos(np.pi*torch.abs(grid_freq-f0)/(4*bandwith))**2*(-torch.sign(torch.abs(grid_freq-f0)/(4*bandwith)-0.5)*0.5+0.5) #hanning + # hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) # B,1,T,F + # condition = (torch.sign(freq_cord_reshape-switch)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-slop*(freq_cord_reshape-switch)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)-torch.abs(self.prior_exp_parameter[2])) * torch.exp(-torch.abs(self.prior_exp_parameter[1])*freq_cord.reshape([1,1,1,self.n_mels])) + torch.abs(self.prior_exp_parameter[2]) # B,1,T,F + + # timbre_parameter = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]).unsqueeze(1) + # condition = (torch.sign(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) * F.softplus(timbre_parameter[...,2:3]) + timbre_parameter[...,3:4] # B,1,T,F + # timbre = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]) + # hamonics = (hamonics.sum(dim=-1)*timbre).unsqueeze(dim=1) # B,1,T,F + # hamonics = (hamonics.sum(dim=-1)*self.timbre).unsqueeze(dim=1) # B,1,T,F + + hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) + # hamonics = 180*F.softplus(self.wave_hamon_amplifier)*(hamonics.sum(dim=-1)).unsqueeze(dim=1) + + return hamonics + + def voicing(self,f0_hz): + #f0: B*1*time, hz + freq_cord = torch.arange(self.n_mels) + time_cord = torch.arange(f0_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq_hz = inverse_mel_scale(grid_freq/(self.n_mels*1.0)) + f0_hz = f0_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, 1 + f0_hz = f0_hz.repeat([1,1,1,self.k]) #B,time,1, self.k + f0_hz = f0_hz*(torch.arange(self.k)+1).reshape([1,1,1,self.k]) + if self.log10: + f0_mel = mel_scale(self.n_mels,f0_hz) + band_low_hz = inverse_mel_scale((f0_mel-1)/(self.n_mels*1.0),n_mels = self.n_mels) + band_up_hz = inverse_mel_scale((f0_mel+1)/(self.n_mels*1.0),n_mels = self.n_mels) + bandwith_hz = band_up_hz-band_low_hz + band_low_mel = mel_scale(self.n_mels,band_low_hz) + band_up_mel = mel_scale(self.n_mels,band_up_hz) + bandwith = band_up_mel-band_low_mel + else: + bandwith_hz = 24.7*(f0_hz*4.37/1000+1) + bandwith = bandwidth_mel(f0_hz,bandwith_hz,self.n_mels) + # bandwith_lower = torch.clamp(f0-bandwith/2,min=1) + # bandwith_upper = f0+bandwith/2 + # bandwith = mel_scale(self.n_mels,bandwith_upper) - mel_scale(self.n_mels,bandwith_lower) + f0 = mel_scale(self.n_mels,f0_hz) + switch = mel_scale(self.n_mels,torch.abs(self.timbre_parameter[0])*f0_hz[...,0]).unsqueeze(1) + slop = (torch.abs(self.timbre_parameter[1])*f0_hz[...,0]).unsqueeze(1) + freq_cord_reshape = freq_cord.reshape([1,1,1,self.n_mels]) + if not self.dbbased: + # sigma = bandwith/(np.sqrt(2*np.log(2))); + sigma = 
bandwith/(2*np.sqrt(2*np.log(2))); + hamonics = torch.exp(-(grid_freq-f0)**2/(2*sigma**2)) #gaussian + # hamonics = (1-((grid_freq_hz-f0_hz)/(2*bandwith_hz/2))**2)*(-torch.sign(torch.abs(grid_freq_hz-f0_hz)/(2*bandwith_hz)-0.5)*0.5+0.5) #welch + else: + # # hamonics = (1-((grid_freq-f0)/(1.75*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(1.75*bandwith)-0.5)*0.5+0.5) #welch + hamonics = (1-((grid_freq-f0)/(2.5*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(2.5*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(3*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(3*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = torch.cos(np.pi*torch.abs(grid_freq-f0)/(4*bandwith))**2*(-torch.sign(torch.abs(grid_freq-f0)/(4*bandwith)-0.5)*0.5+0.5) #hanning + # hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) # B,1,T,F + # condition = (torch.sign(freq_cord_reshape-switch)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-slop*(freq_cord_reshape-switch)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)-torch.abs(self.prior_exp_parameter[2])) * torch.exp(-torch.abs(self.prior_exp_parameter[1])*freq_cord.reshape([1,1,1,self.n_mels])) + torch.abs(self.prior_exp_parameter[2]) # B,1,T,F + + timbre_parameter = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]).unsqueeze(1) + condition = (torch.sign(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*0.5+0.5) + amp = F.softplus(self.wave_hamon_amplifier) if self.dbbased else 180*F.softplus(self.wave_hamon_amplifier) + hamonics = amp * ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) * F.softplus(timbre_parameter[...,2:3]) + timbre_parameter[...,3:4] # B,1,T,F + # timbre = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]) + # hamonics = (hamonics.sum(dim=-1)*timbre).unsqueeze(dim=1) # B,1,T,F + # hamonics = (hamonics.sum(dim=-1)*self.timbre).unsqueeze(dim=1) # B,1,T,F + # return F.softplus(self.wave_hamon_amplifier)*hamonics + return hamonics + + def unvoicing(self,f0,bg=False,mapping=True): + # return (0.25*torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels]))+1 + rnd = torch.randn([f0.shape[0],2,f0.shape[2],self.n_fft if self.wavebased else self.n_mels]) + if mapping: + rnd = self.bgnoise_mapping(rnd) if bg else self.noise_mapping(rnd) + real = rnd[:,0:1] + img = rnd[:,1:2] + if self.dbbased: + return (2*torchaudio.transforms.AmplitudeToDB()(torch.sqrt(real**2 + img**2+1E-10))+80).clamp(min=0)/35 + # return (2*torchaudio.transforms.AmplitudeToDB()(F.softplus(self.wave_noise_amplifier)*torch.sqrt(torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2 + torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2))+80).clamp(min=0)/35 + else: + # return torch.ones([f0.shape[0],1,f0.shape[2],self.n_mels]) + return 180*F.softplus(self.wave_noise_amplifier) * torch.sqrt(real**2 + img**2+1E-10) + # return F.softplus(self.wave_noise_amplifier)*torch.sqrt(torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2 + torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2) + + # return (F.softplus(self.wave_noise_amplifier)) * 
(0.25*torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels]))+1 + # return torch.ones([f0.shape[0],1,f0.shape[2],self.n_mels]) + + def forward(self,components,enable_hamon_excitation=True,enable_noise_excitation=True,enable_bgnoise=True): + # f0: B*1*T, amplitudes: B*2(voicing,unvoicing)*T, freq_formants,bandwidth_formants,amplitude_formants: B*formants*T + amplitudes = components['amplitudes'].unsqueeze(dim=-1) + amplitudes_h = components['amplitudes_h'].unsqueeze(dim=-1) + loudness = components['loudness'].unsqueeze(dim=-1) + f0_hz = components['f0_hz'] + # import pdb;pdb.set_trace() + if self.wavebased: + # self.hamonics = 1800*F.softplus(self.wave_hamon_amplifier)*self.voicing_linear(f0_hz) + # self.noise = 180*self.unvoicing(f0_hz,bg=False,mapping=False) + # self.bgnoise = 18*self.unvoicing(f0_hz,bg=True,mapping=False) + # import pdb;pdb.set_trace() + self.hamonics = self.voicing_wavebased(f0_hz) + self.noise = self.unvoicing_wavebased(f0_hz,bg=False,mapping=False) + self.bgnoise = self.unvoicing_wavebased(f0_hz,bg=True) + else: + self.hamonics = self.voicing(f0_hz) + self.noise = self.unvoicing(f0_hz,bg=False) + self.bgnoise = self.unvoicing(f0_hz,bg=True) + # freq_formants = components['freq_formants']*self.n_mels + # bandwidth_formants = components['bandwidth_formants']*self.n_mels + # excitation = amplitudes[:,0:1]*hamonics + # excitation = loudness*(amplitudes[:,0:1]*hamonics) + # self.noise = self.noise + self.excitation_noise = loudness*(amplitudes[:,-1:])*self.noise if self.power_synth else loudness*amplitudes[:,-1:]*self.noise + duomask = components['freq_formants_noise_hz'].shape[1]>components['freq_formants_hamon_hz'].shape[1] + n_formant_noise = (components['freq_formants_noise_hz'].shape[1]-components['freq_formants_hamon_hz'].shape[1]) if duomask else components['freq_formants_noise_hz'].shape[1] + self.mask_hamon = self.formant_mask(components['freq_formants_hamon_hz'],components['bandwidth_formants_hamon_hz'],components['amplitude_formants_hamon'],linear = self.linear_scale,f0_hz = f0_hz) + self.mask_noise = self.formant_mask(components['freq_formants_noise_hz'],components['bandwidth_formants_noise_hz'],components['amplitude_formants_noise'],linear = self.linear_scale,triangle_mask=False if self.wavebased else True,duomask=duomask,n_formant_noise=n_formant_noise,f0_hz = f0_hz) + # self.mask_hamon = self.formant_mask(components['freq_formants_hamon']*self.n_mels,components['bandwidth_formants_hamon'],components['amplitude_formants_hamon']) + # self.mask_noise = self.formant_mask(components['freq_formants_noise']*self.n_mels,components['bandwidth_formants_noise'],components['amplitude_formants_noise']) + if self.return_wave: + self.hamonics,self.hamonics_wave = self.hamonics + self.mask_hamon, self.hamonic_dist = self.mask_hamon + self.mask_noise, self.mask_noise_only = self.mask_noise + if self.power_synth: + self.excitation_hamon_wave = F.interpolate((loudness[...,-1]*amplitudes[:,0:1][...,-1]).sqrt(),self.hamonics_wave.shape[-1],mode='linear',align_corners=False)*self.hamonics_wave + else: + self.excitation_hamon_wave = F.interpolate(loudness[...,-1]*amplitudes[:,0:1][...,-1],self.hamonics_wave.shape[-1],mode='linear',align_corners=False)*self.hamonics_wave + self.hamonics_wave_ = (self.excitation_hamon_wave*self.hamonic_dist).sum(1,keepdim=True) + + self.mask_hamon_sum = self.mask_hamon.sum(dim=-1) + self.mask_noise_sum = self.mask_noise.sum(dim=-1) + bgdist = F.softplus(self.bgnoise_amp)*self.noise_dist if self.noise_from_data else F.softplus(self.bgnoise_dist) + if 
self.power_synth: + self.excitation_hamon = loudness*(amplitudes[:,0:1])*self.hamonics + else: + self.excitation_hamon = loudness*amplitudes[:,0:1]*self.hamonics + # import pdb;pdb.set_trace() + self.noise_excitation = self.excitation_noise*self.mask_noise_sum + if self.return_wave: + self.noise_excitation_wave = 2*inverse_spec_to_audio(self.noise_excitation.squeeze(1).permute(0,2,1),n_fft=self.n_fft*2-1,power_synth=self.power_synth) + self.noise_excitation_wave = F.pad(self.noise_excitation_wave,[0,self.hamonics_wave_.shape[2]-self.noise_excitation_wave.shape[1]]) + self.noise_excitation_wave = self.noise_excitation_wave.unsqueeze(1) + self.rec_wave = self.noise_excitation_wave+self.hamonics_wave_ + if self.wavebased: + # import pdb; pdb.set_trace() + bgn = bgdist*self.bgnoise*0.0003 if (self.add_bgnoise and enable_bgnoise) else 0 + speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.noise_excitation if enable_noise_excitation else 0) + bgn + # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.excitation_noise*self.mask_noise_sum if enable_noise_excitation else 0) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. + self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + # speech = speech if self.power_synth else speech**2 + speech = (torchaudio.transforms.AmplitudeToDB()(speech).clamp(min=self.noise_db)-self.noise_db)/(self.max_db-self.noise_db)*2-1 + else: + # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. + self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + (self.silient*torch.ones(self.mask_hamon_sum.shape) if self.dbbased else 0) + speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.noise_excitation if enable_noise_excitation else 0) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. 
+ self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + (self.silient*torch.ones(self.mask_hamon_sum.shape) if self.dbbased else 0) + # speech = self.excitation_hamon*self.mask_hamon_sum + (self.excitation_noise*self.mask_noise_sum if enable_noise_excitation else 0) + self.silient*torch.ones(self.mask_hamon_sum.shape) + if not self.dbbased: + speech = db(speech) + + + # import pdb;pdb.set_trace() + if self.return_wave: + return speech,self.rec_wave + else: + return speech + +@ENCODERS.register("EncoderFormant") +class FormantEncoder(nn.Module): + def __init__(self, n_mels=64, n_formants=4,n_formants_noise=2,min_octave=-31,max_octave=96,wavebased=False,n_fft=256,noise_db=-50,max_db=22.5,broud=True,power_synth=False,hop_length=128): + super(FormantEncoder, self).__init__() + self.wavebased = wavebased + self.n_mels = n_mels + self.n_formants = n_formants + self.n_formants_noise = n_formants_noise + self.min_octave = min_octave + self.max_octave = max_octave + self.noise_db = noise_db + self.max_db = max_db + self.broud = broud + self.n_fft = n_fft + self.power_synth=power_synth + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,700.,2700.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + # self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,500.,500.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_bias = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_thres = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_bias,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + nn.init.constant_(self.formant_bandwitdh_thres,0) + + # self.formant_freq_limits = torch.cumsum(self.formant_freq_limits_diff,dim=0) + # self.formant_freq_limits_mel = torch.cat([torch.tensor([0.]),mel_scale(n_mels,self.formant_freq_limits)/n_mels]) + # self.formant_freq_limits_mel_diff = torch.reshape(self.formant_freq_limits_mel[1:]-self.formant_freq_limits_mel[:-1],[1,3,1]) + if broud: + if wavebased: + self.conv1_narrow = ln.Conv1d(n_fft,64,3,1,1) + self.conv1_mel = ln.Conv1d(128,64,3,1,1) + self.norm1_mel = nn.GroupNorm(32,64) + self.conv2_mel = ln.Conv1d(64,128,3,1,1) + self.norm2_mel = nn.GroupNorm(32,128) + self.conv_fundementals_mel = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals_mel = 
nn.GroupNorm(32,128) + self.f0_drop_mel = nn.Dropout() + else: + self.conv1_narrow = ln.Conv1d(n_mels,64,3,1,1) + self.norm1_narrow = nn.GroupNorm(32,64) + self.conv2_narrow = ln.Conv1d(64,128,3,1,1) + self.norm2_narrow = nn.GroupNorm(32,128) + + self.conv_fundementals_narrow = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals_narrow = nn.GroupNorm(32,128) + self.f0_drop_narrow = nn.Dropout() + if wavebased: + self.conv_f0_narrow = ln.Conv1d(256,1,1,1,0) + else: + self.conv_f0_narrow = ln.Conv1d(128,1,1,1,0) + + self.conv_amplitudes_narrow = ln.Conv1d(128,2,1,1,0) + self.conv_amplitudes_h_narrow = ln.Conv1d(128,2,1,1,0) + + if wavebased: + self.conv1 = ln.Conv1d(n_fft,64,3,1,1) + else: + self.conv1 = ln.Conv1d(n_mels,64,3,1,1) + self.norm1 = nn.GroupNorm(32,64) + self.conv2 = ln.Conv1d(64,128,3,1,1) + self.norm2 = nn.GroupNorm(32,128) + + self.conv_fundementals = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,128) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(128,1,1,1,0) + + self.conv_amplitudes = ln.Conv1d(128,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(128,2,1,1,0) + # self.conv_loudness = nn.Sequential(ln.Conv1d(n_fft if wavebased else n_mels,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,1,1,1,0,bias_initial=0.5),) + self.conv_loudness = nn.Sequential(ln.Conv1d(n_fft if wavebased else n_mels,128,1,1,0), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1,0), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,1,0,bias_initial=-9. if power_synth else -4.6),) + + if self.broud: + self.conv_formants = ln.Conv1d(128,128,3,1,1) + else: + self.conv_formants = ln.Conv1d(128,128,3,1,1) + self.norm_formants = nn.GroupNorm(32,128) + self.conv_formants_freqs = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(128,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + + self.amplifier = Parameter(torch.Tensor(1)) + self.bias = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.amplifier,1.0) + nn.init.constant_(self.bias,-0.5) + + def forward(self,x,x_denoise=None,duomask=False,noise_level = None,x_amp=None): + x = x.squeeze(dim=1).permute(0,2,1) #B * f * T + if x_denoise is not None: + x_denoise = x_denoise.squeeze(dim=1).permute(0,2,1) + # x_denoise_amp = amplitude(x_denoise,self.noise_db,self.max_db) + # import pdb; pdb.set_trace() + if x_amp is None: + x_amp = amplitude(x,self.noise_db,self.max_db,trim_noise=True) + else: + x_amp = x_amp.squeeze(dim=1).permute(0,2,1) + hann_win = torch.hann_window(5,periodic=False).reshape([1,1,5,1]) + x_smooth = F.conv2d(x.unsqueeze(1).transpose(-2,-1),hann_win,padding=[2,0]).transpose(-2,-1).squeeze(1) + x_amp_smooth = F.conv2d(x_amp.unsqueeze(1).transpose(-2,-1),hann_win,padding=[2,0]).transpose(-2,-1).squeeze(1) + # loudness = F.softplus(self.amplifier)*(torch.mean(x_denoise_amp,dim=1,keepdim=True)) + # loudness = F.relu(F.softplus(self.amplifier)*(torch.mean(x_amp,dim=1,keepdim=True)-noise_level*0.0003)) + # loudness = torch.mean((x*0.5+0.5) if x_denoise is None else (x_denoise*0.5+0.5),dim=1,keepdim=True) + # loudness = F.softplus(self.amplifier)*(loudness) + # loudness = F.softplus(self.amplifier)*torch.mean(x_amp,dim=1,keepdim=True) + # loudness = 
F.softplus(self.amplifier)*F.relu(loudness - F.softplus(self.bias)) + if self.power_synth: + loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x_smooth)) + else: + loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x_smooth)) + # loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x)) + # loudness = F.relu(self.conv_loudness(x)) + + # if not self.power_synth: + # loudness = loudness.sqrt() + + if self.broud: + x_narrow = x + x_narrow = F.leaky_relu(self.norm1_narrow(self.conv1_narrow(x_narrow)),0.2) + x_common_narrow = F.leaky_relu(self.norm2_narrow(self.conv2_narrow(x_narrow)),0.2) + amplitudes = F.softmax(self.conv_amplitudes_narrow(x_common_narrow),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h_narrow(x_common_narrow),dim=1) + x_fundementals_narrow = self.f0_drop_narrow(F.leaky_relu(self.norm_fundementals_narrow(self.conv_fundementals_narrow(x_common_narrow)),0.2)) + + x_amp = amplitude(x.unsqueeze(1),self.noise_db,self.max_db).transpose(-2,-1) + x_mel = to_db(torchaudio.transforms.MelScale(f_max=8000,n_stft=self.n_fft)(x_amp.transpose(-2,-1)),self.noise_db,self.max_db).squeeze(1) + x = F.leaky_relu(self.norm1_mel(self.conv1_mel(x_mel)),0.2) + x_common_mel = F.leaky_relu(self.norm2_mel(self.conv2_mel(x)),0.2) + x_fundementals_mel = self.f0_drop_mel(F.leaky_relu(self.norm_fundementals_mel(self.conv_fundementals_mel(x_common_mel)),0.2)) + + f0_hz = torch.sigmoid(self.conv_f0_narrow(torch.cat([x_fundementals_narrow,x_fundementals_mel],dim=1))) * 120 + 180 # 180hz < f0 < 300 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + hann_win = torch.hann_window(21,periodic=False).reshape([1,1,21,1]) + x = to_db(F.conv2d(x_amp,hann_win,padding=[10,0]).transpose(-2,-1),self.noise_db,self.max_db).squeeze(1) + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x_common = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + + + + else: + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x_common = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + + + # loudness = F.relu(self.conv_loudness(x_common)) + # loudness = F.relu(self.conv_loudness(x_common)) +(10**(self.noise_db/10.-1) if self.wavebased else 0) + amplitudes = F.softmax(self.conv_amplitudes(x_common),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + + # x_fundementals = F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) + # f0 in mel: + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) + # f0 = F.tanh(self.conv_f0(x_fundementals)) * (16/64)*(self.n_mels/64) # 72hz < f0 < 446 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + + # f0 in hz: + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 120 + 180 # 180hz < f0 < 300 hz + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 120 + 180 # 180hz < f0 < 300 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 528 + 88 # 88hz < f0 < 616 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 302 + 118 # 118hz < 
f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 240 + 180 # 180hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 260 + 160 # 160hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + # # relative freq: + # formants_freqs_hz = formants_freqs*(self.formant_freq_limits_diff[:,:self.n_formants]-self.formant_freq_limits_diff_low[:,:self.n_formants])+self.formant_freq_limits_diff_low[:,:self.n_formants] + # # formants_freqs_hz = formants_freqs*6839 + # formants_freqs_hz = torch.cumsum(formants_freqs_hz,dim=1) + + # abs freq: + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + # formants_freqs_hz = formants_freqs*6839 + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) *6839 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 150 + # formants_bandwidth_hz = 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+100) + # formants_bandwidth_hz = (torch.sigmoid(self.conv_formants_bandwidth(x_formants))) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100) #good for spec based method + # formants_bandwidth_hz = ((3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.0125*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + # formants_bandwidth_hz = (2**(torch.tanh(self.formant_bandwitdh_slop))*0.001*torch.relu(formants_freqs_hz-4000*torch.sigmoid(self.formant_bandwitdh_thres))+375*2**(torch.tanh(self.formant_bandwitdh_bias))) + # formants_bandwidth_hz = torch.exp(0.4*torch.tanh(self.conv_formants_bandwidth(x_formants))) * (0.00625*torch.relu(formants_freqs_hz-0)+375) + # formants_bandwidth_hz = (100*(torch.tanh(self.conv_formants_bandwidth(x_formants))) + (0.035*torch.relu(formants_freqs_hz-950)+250)) if self.wavebased else ((3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.0125*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + # formants_bandwidth_hz = (100*(torch.tanh(self.conv_formants_bandwidth(x_formants))) + (0.035*torch.relu(formants_freqs_hz-950)+250)) if self.wavebased else ((torch.sigmoid(self.conv_formants_bandwidth(x_formants))+0.2) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * (0.075*3*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+(2*torch.sigmoid(self.formant_bandwitdh_ratio)+1)*50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+50) + formants_bandwidth = 
bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + # formants_bandwidth_upper = formants_freqs_hz+formants_bandwidth_hz/2 + # formants_bandwidth_lower = torch.clamp(formants_freqs_hz-formants_bandwidth_hz/2,min=1) + # formants_bandwidth = (mel_scale(self.n_mels,formants_bandwidth_upper) - mel_scale(self.n_mels,formants_bandwidth_lower))/(self.n_mels*1.0) + # formants_amplitude = F.softmax(torch.cumsum(-F.relu(self.conv_formants_amplitude(x_formants)),dim=1),dim=1) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + # # relative freq: + # formants_freqs_hz = formants_freqs*(self.formant_freq_limits_diff[:,:self.n_formants]-self.formant_freq_limits_diff_low[:,:self.n_formants])+self.formant_freq_limits_diff_low[:,:self.n_formants] + # # formants_freqs_hz = formants_freqs*6839 + # formants_freqs_hz = torch.cumsum(formants_freqs_hz,dim=1) + + # abs freq: + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + if duomask: + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + # formants_freqs_hz = formants_freqs*6839 + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) *6839 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 150 + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 8000 + 2000 + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + # formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 8000 + 2000 #2000-10000 + # formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 2000 #0-2000 + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + if duomask: + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 4000 + 1000 + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 4000 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * (0.075*3*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+(2*torch.sigmoid(self.formant_bandwitdh_ratio)+1)*50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+50) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + # 
formants_bandwidth_upper = formants_freqs_hz+formants_bandwidth_hz/2 + # formants_bandwidth_lower = torch.clamp(formants_freqs_hz-formants_bandwidth_hz/2,min=1) + # formants_bandwidth = (mel_scale(self.n_mels,formants_bandwidth_upper) - mel_scale(self.n_mels,formants_bandwidth_lower))/(self.n_mels*1.0) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + if duomask: + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + +class FromECoG(nn.Module): + def __init__(self, outputs,residual=False,shape='3D'): + super().__init__() + self.residual=residual + if shape =='3D': + self.from_ecog = ln.Conv3d(1, outputs, [9,1,1], 1, [4,0,0]) + else: + self.from_ecog = ln.Conv2d(1, outputs, [9,1], 1, [4,0]) + + def forward(self, x): + x = self.from_ecog(x) + if not self.residual: + x = F.leaky_relu(x, 0.2) + return x + +class ECoGMappingBlock(nn.Module): + def __init__(self, inputs, outputs, kernel_size,dilation=1,fused_scale=True,residual=False,resample=[],pool=None,shape='3D'): + super(ECoGMappingBlock, self).__init__() + self.residual = residual + self.pool = pool + self.inputs_resample = resample + self.dim_missmatch = (inputs!=outputs) + self.resample = resample + if not self.resample: + self.resample=1 + self.padding = list(np.array(dilation)*(np.array(kernel_size)-1)//2) + if shape=='2D': + conv=ln.Conv2d + maxpool = nn.MaxPool2d + avgpool = nn.AvgPool2d + if shape=='3D': + conv=ln.Conv3d + maxpool = nn.MaxPool3d + avgpool = nn.AvgPool3d + # self.padding = [dilation[i]*(kernel_size[i]-1)//2 for i in range(len(dilation))] + if residual: + self.norm1 = nn.GroupNorm(min(inputs,32),inputs) + else: + self.norm1 = nn.GroupNorm(min(outputs,32),outputs) + if pool is None: + self.conv1 = conv(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False) + else: + self.conv1 = conv(inputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.pool1 = maxpool(self.resample,self.resample) if self.pool=='Max' else avgpool(self.resample,self.resample) + if self.inputs_resample or self.dim_missmatch: + if pool is None: + self.convskip = conv(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False) + else: + self.convskip = conv(inputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.poolskip = maxpool(self.resample,self.resample) if self.pool=='Max' else avgpool(self.resample,self.resample) + + self.conv2 = conv(outputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.norm2 = nn.GroupNorm(min(outputs,32),outputs) + + def forward(self,x): + if self.residual: + x = F.leaky_relu(self.norm1(x),0.2) + if self.inputs_resample or self.dim_missmatch: + # x_skip = 
F.avg_pool3d(x,self.resample,self.resample) + x_skip = self.convskip(x) + if self.pool is not None: + x_skip = self.poolskip(x_skip) + else: + x_skip = x + x = F.leaky_relu(self.norm2(self.conv1(x)),0.2) + if self.pool is not None: + x = self.poolskip(x) + x = self.conv2(x) + x = x_skip + x + else: + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + return x + + + +@ECOG_ENCODER.register("ECoGMappingBottleneck") +class ECoGMapping_Bottleneck(nn.Module): + def __init__(self,n_mels,n_formants,n_formants_noise=1,compute_db_loudness=True): + super(ECoGMapping_Bottleneck, self).__init__() + self.n_formants = n_formants + self.n_mels = n_mels + self.n_formants_noise = n_formants_noise + + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_ratio = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_ratio,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + + self.compute_db_loudness = compute_db_loudness + if compute_db_loudness: + self.conv_loudness = ln.Conv1d(32,1,1,1,0) + else: + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) 
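+ # The 3D conv backbone below shrinks the (time, 15, 15) ECoG grid, gates electrodes with a
+ # sigmoid mask, max-reduces the two spatial dims, and then upsamples time with four
+ # transposed 1D convs before the loudness/f0/formant heads.
+ # Note: ECoGMappingBlock only selects max pooling when pool == 'Max', so the pool='MAX'
+ # arguments below fall through to average pooling as written.
+ # Note: conv_loudness is re-assigned after conv_amplitudes_h below, which overrides the
+ # compute_db_loudness branch above.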
+ self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + self.conv3 = ECoGMappingBlock(64,128,[3,3,3],residual=True,resample = [2,2,2],pool='MAX') + self.conv4 = ECoGMappingBlock(128,256,[3,3,3],residual=True,resample = [2,2,2],pool='MAX') + self.norm = nn.GroupNorm(32,256) + self.conv5 = ln.Conv1d(256,256,3,1,1) + self.norm2 = nn.GroupNorm(32,256) + self.conv6 = ln.ConvTranspose1d(256, 128, 3, 2, 1, transform_kernel=True) + self.norm3 = nn.GroupNorm(32,128) + self.conv7 = ln.ConvTranspose1d(128, 64, 3, 2, 1, transform_kernel=True) + self.norm4 = nn.GroupNorm(32,64) + self.conv8 = ln.ConvTranspose1d(64, 32, 3, 2, 1, transform_kernel=True) + self.norm5 = nn.GroupNorm(32,32) + self.conv9 = ln.ConvTranspose1d(32, 32, 3, 2, 1, transform_kernel=True) + self.norm6 = nn.GroupNorm(32,32) + + self.conv_fundementals = ln.Conv1d(32,32,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,32) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(32,1,1,1,0) + self.conv_amplitudes = ln.Conv1d(32,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(32,2,1,1,0) + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) + + self.conv_formants = ln.Conv1d(32,32,3,1,1) + self.norm_formants = nn.GroupNorm(32,32) + self.conv_formants_freqs = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(32,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + + + def forward(self,ecog,mask_prior,mni): + x_common_all = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,4:] + if mask_prior is not None: + mask = mask*mask_prior_d + x = x[:,:,4:] + x = x*mask + x = self.conv3(x) + x = self.conv4(x) + x = x.max(-1)[0].max(-1)[0] + x = self.conv5(F.leaky_relu(self.norm(x),0.2)) + x = self.conv6(F.leaky_relu(self.norm2(x),0.2)) + x = self.conv7(F.leaky_relu(self.norm3(x),0.2)) + x = self.conv8(F.leaky_relu(self.norm4(x),0.2)) + x = self.conv9(F.leaky_relu(self.norm5(x),0.2)) + x_common = F.leaky_relu(self.norm6(x),0.2) + x_common_all += [x_common] + + x_common = torch.cat(x_common_all,dim=0) + if self.compute_db_loudness: + loudness = F.sigmoid(self.conv_loudness(x_common)) #0-1 + loudness = loudness*200-100 #-100 ~ 100 db + loudness = 10**(loudness/10.) 
#amplitude + else: + loudness = F.softplus(self.conv_loudness(x_common)) + logits = self.conv_amplitudes(x_common) + amplitudes = F.softmax(logits,dim=1) + amplitudes_logsoftmax = F.log_softmax(logits,dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + + +@ECOG_ENCODER.register("ECoGMappingBottlenecklstm1") +class ECoGMapping_Bottlenecklstm1(nn.Module): + def __init__(self,n_mels,n_formants,n_formants_noise=1,compute_db_loudness=True): + 
super(ECoGMapping_Bottlenecklstm1, self).__init__() + self.n_formants = n_formants + self.n_mels = n_mels + self.n_formants_noise = n_formants_noise + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + print ('******************************************************') + print ('*******************LSTM ECOG ENCODER******************') + print ('******************************************************') + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_ratio = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_ratio,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + self.compute_db_loudness = compute_db_loudness + if compute_db_loudness: + self.conv_loudness = ln.Conv1d(32,1,1,1,0) + else: + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) + self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + self.conv3 = ECoGMappingBlock(64,64,[3,3,3],residual=True,resample = [1,2,2],pool='MAX') + self.conv4 = ECoGMappingBlock(64,64,[3,3,3],residual=True,resample = [1,2,2],pool='MAX') + #self.norm = nn.GroupNorm(32,256) + #self.conv5 = ln.Conv1d(256,256,3,1,1) + self.norm2 = nn.GroupNorm(32,64) + self.conv6 = ln.ConvTranspose1d(64, 64, 3, 2, 1, transform_kernel=True) + self.norm3 = nn.GroupNorm(32,64) + self.conv7 = ln.ConvTranspose1d(64, 64, 3, 2, 1, transform_kernel=True) + #self.norm4 = nn.GroupNorm(32,64) + #self.conv8 = ln.ConvTranspose1d(64, 32, 3, 2, 1, transform_kernel=True) + #self.norm5 = nn.GroupNorm(32,32) + #self.conv9 = ln.ConvTranspose1d(32, 32, 3, 2, 1, transform_kernel=True) + self.norm6 = nn.GroupNorm(32,32) + self.lstm = nn.LSTM(64, 16, num_layers=3, bidirectional=True, dropout=0.1, batch_first=True) + + self.conv_fundementals = ln.Conv1d(32,32,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,32) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(32,1,1,1,0) + self.conv_amplitudes = ln.Conv1d(32,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(32,2,1,1,0) + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) 
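+ # Variant note: this encoder keeps the 3D conv front end but models the temporal sequence
+ # with a 3-layer bidirectional LSTM (hidden 16 per direction -> 32 channels), applied after
+ # the transposed convs restore the frame rate. As in the class above, the conv_loudness
+ # assigned here overrides the compute_db_loudness branch earlier in __init__.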
+ + self.conv_formants = ln.Conv1d(32,32,3,1,1) + self.norm_formants = nn.GroupNorm(32,32) + self.conv_formants_freqs = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(32,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + + def forward(self,ecog,mask_prior,mni): + x_common_all = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + #print ('x,mask_prior_d,mni :', x.shape,mask_prior_d.shape, mni.shape) + x = self.from_ecog(x) + #print ('from ecog: ',x.shape) + x = self.conv1(x) + #print ('conv1 x: ',x.shape) + x = self.conv2(x) + #print ('conv2 x: ',x.shape) + mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,4:] + if mask_prior is not None: + mask = mask*mask_prior_d + #print ('mask: ',mask.shape) + x = x[:,:,4:] + x = x*mask + #print ('attention x: ', x.shape) + x = self.conv3(x) + #print ('conv3 x: ',x.shape) + x = self.conv4(x) + #print ('conv4 x: ',x.shape) + x = x.max(-1)[0].max(-1)[0] + #print ('max x: ',x.shape) + #x = self.conv5(F.leaky_relu(self.norm(x),0.2)) + #print ('conv5 x: ',x.shape) + x = self.conv6(F.leaky_relu(self.norm2(x),0.2)) + #print ('conv6 x: ',x.shape) + x = self.conv7(F.leaky_relu(self.norm3(x),0.2)) + #print ('conv7 x: ',x.shape) + x = x .permute(0,2,1) + #print ('reshape x:',x.shape) + #x = self.conv8(F.leaky_relu(self.norm4(x),0.2)) + #print ('conv8 x: ',x.shape) + #x = self.conv9(F.leaky_relu(self.norm5(x),0.2)) + #print ('conv9 x: ',x.shape) + x = self.lstm(x)[0] + #print ('lstm x:',x.shape) + x = x .permute(0,2,1) + x_common = F.leaky_relu(self.norm6(x),0.2) + #print ('common x: ',x_common.shape) + x_common_all += [x_common] + + x_common = torch.cat(x_common_all,dim=0) + if self.compute_db_loudness: + loudness = F.sigmoid(self.conv_loudness(x_common)) #0-1 + loudness = loudness*200-100 #-100 ~ 100 db + loudness = 10**(loudness/10.) 
#amplitude + else: + loudness = F.softplus(self.conv_loudness(x_common)) + logits = self.conv_amplitudes(x_common) + amplitudes = F.softmax(logits,dim=1) + amplitudes_logsoftmax = F.log_softmax(logits,dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + +@ECOG_ENCODER.register("ECoGMappingBottlenecklstm2") +class ECoGMapping_Bottlenecklstm2(nn.Module): + def __init__(self,n_mels,n_formants,n_formants_noise=1,compute_db_loudness=True): + 
super(ECoGMapping_Bottlenecklstm2, self).__init__() + self.n_formants = n_formants + self.n_mels = n_mels + self.n_formants_noise = n_formants_noise + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + print ('******************************************************') + print ('*******************LSTM ECOG ENCODER******************') + print ('******************************************************') + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_ratio = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_ratio,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + self.compute_db_loudness = compute_db_loudness + if compute_db_loudness: + self.conv_loudness = ln.Conv1d(32,1,1,1,0) + else: + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) + self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,resample = [2,1,1],pool='MAX') + #self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,resample = [2,1,1],pool='MAX') + #self.norm_mask = nn.GroupNorm(32,64) + #self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + self.conv3 = ECoGMappingBlock(32,32,[3,3,3],residual=True,resample = [1,2,2],pool='MAX') + self.conv4 = ECoGMappingBlock(32,32,[3,3,3],residual=True,resample = [1,2,2],pool='MAX') + #self.norm = nn.GroupNorm(32,256) + #self.conv5 = ln.Conv1d(256,256,3,1,1) + self.norm3 = nn.GroupNorm(32,32) + self.conv7 = ln.ConvTranspose1d(32, 32, 3, 2, 1, transform_kernel=True) + self.norm6 = nn.GroupNorm(32,32) + self.lstm = nn.LSTM(32, 16, num_layers=4, bidirectional=True, dropout=0.1, batch_first=True) + + self.conv_fundementals = ln.Conv1d(32,32,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,32) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(32,1,1,1,0) + self.conv_amplitudes = ln.Conv1d(32,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(32,2,1,1,0) + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) 
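+ # Variant note: lstm2 drops the attentional mask and uses a shallower front end
+ # (from_ecog + conv1 + two 32-channel 3D blocks), a single transposed conv, and a
+ # 4-layer bidirectional LSTM (hidden 16 per direction -> 32 channels). The conv_loudness
+ # assigned here again overrides the compute_db_loudness branch above.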
+ + self.conv_formants = ln.Conv1d(32,32,3,1,1) + self.norm_formants = nn.GroupNorm(32,32) + self.conv_formants_freqs = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(32,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + + def forward(self,ecog,mask_prior,mni): + x_common_all = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + #print ('x,mask_prior_d,mni :', x.shape,mask_prior_d.shape, mni.shape) + x = self.from_ecog(x) + #print ('from ecog: ',x.shape) + x = self.conv1(x) + #print ('conv1 x: ',x.shape) + x = x[:,:,4:-4] + #print ('x truncate: ',x.shape) + x = self.conv3(x) + #print ('conv3 x: ',x.shape) + x = self.conv4(x) + #print ('conv4 x: ',x.shape) + x = x.max(-1)[0].max(-1)[0] + #print ('max x: ',x.shape) + x = self.conv7(F.leaky_relu(self.norm3(x),0.2)) + #print ('conv7 x: ',x.shape) + x = x .permute(0,2,1) + #print ('reshape x:',x.shape) + x = self.lstm(x)[0] + #print ('lstm x:',x.shape) + x = x .permute(0,2,1) + x_common = F.leaky_relu(self.norm6(x),0.2) + #print ('common x: ',x_common.shape) + x_common_all += [x_common] + + x_common = torch.cat(x_common_all,dim=0) + if self.compute_db_loudness: + loudness = F.sigmoid(self.conv_loudness(x_common)) #0-1 + loudness = loudness*200-100 #-100 ~ 100 db + loudness = 10**(loudness/10.) #amplitude + else: + loudness = F.softplus(self.conv_loudness(x_common)) + logits = self.conv_amplitudes(x_common) + amplitudes = F.softmax(logits,dim=1) + amplitudes_logsoftmax = F.log_softmax(logits,dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + 
formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + + +def RNN_layer(in_ch,out_ch,rnn_type = 'LSTM',rnn_layers = 4,bidirection = True): + dropoutratio = 0 if rnn_layers ==1 else 0.1 + if rnn_type =='LSTM': + return nn.LSTM(in_ch,out_ch, num_layers=rnn_layers, bidirectional=bidirection, batch_first=True,dropout=dropoutratio) + elif rnn_type =='GRU': + return nn.GRU(in_ch,out_ch, num_layers=rnn_layers, bidirectional=bidirection, batch_first=True,dropout=dropoutratio) + +@ECOG_ENCODER.register("ECoGMappingBottlenecklstm_pure") +class ECoGMapping_Bottlenecklstm_new(nn.Module): + def __init__(self,n_mels,n_formants,n_formants_noise=1,onedconfirst=True,rnn_type = 'LSTM',\ + rnn_layers = 4,compute_db_loudness=True,bidirection = True): + super(ECoGMapping_Bottlenecklstm_new, self).__init__() + self.n_formants = n_formants + self.n_mels = n_mels + self.n_formants_noise = n_formants_noise + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + print ('******************************************************') + print ('*******************PURE!! 
LSTM ECOG ENCODER******************') + print ('******************************************************') + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + self.onedconfirst = onedconfirst + if onedconfirst: + self.prelayer = ln.Conv2d(1,1,(9,1),1,(4,0))#ln.Conv1d(144,144,3,2,1) + else: + self.prelayer = RNN_layer(80, 40,rnn_type = rnn_type, rnn_layers = 1,bidirection = True) + + self.formant_bandwitdh_ratio = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_ratio,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + self.compute_db_loudness = compute_db_loudness + + + self.lstm = RNN_layer(80, 32//(bidirection+1), rnn_layers=rnn_layers, bidirection=bidirection) + + if compute_db_loudness: + self.conv_loudness = RNN_layer(32, 1, rnn_layers=1, bidirection=False) + else: + self.conv_loudness = RNN_layer(32, 1, rnn_layers=1, bidirection=False) + self.conv_fundementals = RNN_layer(32, 32//(bidirection+1), rnn_layers=1, bidirection=bidirection) + #self.norm_fundementals = nn.GroupNorm(32,32) + self.f0_drop = nn.Dropout() + self.conv_f0 = RNN_layer(32, 1, rnn_layers=1, bidirection=False) + self.conv_amplitudes = RNN_layer(32, 2//(bidirection+1), rnn_layers=1, bidirection=bidirection) + self.conv_amplitudes_h = RNN_layer(32, 2//(bidirection+1), rnn_layers=1, bidirection=bidirection) + self.conv_formants = RNN_layer(32, 32//(bidirection+1), rnn_layers=1, bidirection=bidirection) + #self.norm_formants = nn.GroupNorm(32,32) + self.conv_formants_freqs = RNN_layer(32, n_formants//(bidirection+1), rnn_layers=1, bidirection=bidirection) + self.conv_formants_bandwidth = RNN_layer(32, n_formants//(bidirection+1), rnn_layers=1, bidirection=bidirection) + self.conv_formants_amplitude = RNN_layer(32, n_formants//(bidirection+1), rnn_layers=1, bidirection=bidirection) + #self.conv_formants_freqs_noise = RNN_layer(32, n_formants_noise//(bidirection+1), rnn_layers=1, bidirection=bidirection) + #self.conv_formants_bandwidth_noise = RNN_layer(32, n_formants_noise//(bidirection+1), rnn_layers=1, bidirection=bidirection) + #self.conv_formants_amplitude_noise = RNN_layer(32, n_formants_noise//(bidirection+1), rnn_layers=1, bidirection=bidirection) + self.conv_formants_freqs_noise = RNN_layer(32, n_formants_noise, rnn_layers=1, bidirection=False) + self.conv_formants_bandwidth_noise = RNN_layer(32, n_formants_noise , rnn_layers=1, bidirection=False) + self.conv_formants_amplitude_noise = RNN_layer(32, n_formants_noise , rnn_layers=1, bidirection=False) + + + def forward(self,ecog,mask_prior,mni): + x_common_all = [] + for d in range(len(ecog)): + x = ecog[d] + print ('x: ',x.shape) + x = torch.squeeze(self.prelayer(torch.unsqueeze(x,dim=1)),dim=1) if self.onedconfirst else self.prelayer(x)[0] + print ('x prelayer, convfirst ',self.onedconfirst,x.shape) + x = x[:,8:-8] + x = self.lstm(x)[0] + print ('lstm x:',x.shape) + x_common = torch.tanh(x).permute(0,2,1) + print ('common x: ',x_common.shape) + x_common_all += [x_common] + + x_common = torch.cat(x_common_all,dim=0).permute(0,2,1) + print ('common x: ',x_common.shape) + if self.compute_db_loudness: + loudness = torch.sigmoid(self.conv_loudness(x_common)[0].permute(0,2,1)) #0-1 + loudness = 
loudness*200-100 #-100 ~ 100 db + loudness = 10**(loudness/10.) #amplitude + else: + loudness = F.softplus(self.conv_loudness(x_common)[0].permute(0,2,1)) + logits = self.conv_amplitudes(x_common)[0].permute(0,2,1) + amplitudes = F.softmax(logits,dim=1) + amplitudes_logsoftmax = F.log_softmax(logits,dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common)[0].permute(0,2,1),dim=1) + x_fundementals = self.f0_drop(F.leaky_relu( self.conv_fundementals(x_common)[0],0.2)) + print ('x_fundementals: ',x_fundementals.shape) + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)[0].permute(0,2,1)) * 332 + 88 # 88hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu( self.conv_formants(x_common)[0],0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)[0].permute(0,2,1)) + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants)[0].permute(0,2,1) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)[0].permute(0,2,1)) + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants)[0].permute(0,2,1) + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants)[0].permute(0,2,1) + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 
'amplitude_formants_noise':formants_amplitude_noise, + } + return components + + +class BackBone(nn.Module): + def __init__(self,attentional_mask=True): + super(BackBone, self).__init__() + self.attentional_mask = attentional_mask + self.from_ecog = FromECoG(16,residual=True,shape='2D') + self.conv1 = ECoGMappingBlock(16,32,[5,1],residual=True,resample = [1,1],shape='2D') + self.conv2 = ECoGMappingBlock(32,64,[3,1],residual=True,resample = [1,1],shape='2D') + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv2d(64,1,[3,1],1,[1,0]) + + def forward(self,ecog): + x_common_all = [] + mask_all=[] + for d in range(len(ecog)): + x = ecog[d] + x = x.unsqueeze(1) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + if self.attentional_mask: + mask = F.relu(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,16:] + x = x[:,:,16:] + mask_all +=[mask] + else: + # mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + # mask = mask[:,:,16:] + x = x[:,:,16:] + # x = x*mask + + x_common_all +=[x] + + x_common = torch.cat(x_common_all,dim=0) + if self.attentional_mask: + mask = torch.cat(mask_all,dim=0) + return x_common,mask.squeeze(1) if self.attentional_mask else None + +class ECoGEncoderFormantHeads(nn.Module): + def __init__(self,inputs,n_mels,n_formants): + super(ECoGEncoderFormantHeads,self).__init__() + self.n_mels = n_mels + self.f0 = ln.Conv1d(inputs,1,1) + self.loudness = ln.Conv1d(inputs,1,1) + self.amplitudes = ln.Conv1d(inputs,2,1) + self.freq_formants = ln.Conv1d(inputs,n_formants,1) + self.bandwidth_formants = ln.Conv1d(inputs,n_formants,1) + self.amplitude_formants = ln.Conv1d(inputs,n_formants,1) + + def forward(self,x): + loudness = F.relu(self.loudness(x)) + f0 = torch.sigmoid(self.f0(x)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + amplitudes = F.softmax(self.amplitudes(x),dim=1) + freq_formants = torch.sigmoid(self.freq_formants(x)) + freq_formants = torch.cumsum(freq_formants,dim=1) + bandwidth_formants = torch.sigmoid(self.bandwidth_formants(x)) + amplitude_formants = F.softmax(self.amplitude_formants(x),dim=1) + return {'f0':f0, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'freq_formants':freq_formants, + 'bandwidth_formants':bandwidth_formants, + 'amplitude_formants':amplitude_formants,} + +@ECOG_ENCODER.register("ECoGMappingTransformer") +class ECoGMapping_Transformer(nn.Module): + def __init__(self,n_mels,n_formants,SeqLen=128,hidden_dim=256,dim_feedforward=256,encoder_only=False,attentional_mask=False,n_heads=1,non_local=False): + super(ECoGMapping_Transformer, self).__init__() + self.n_mels = n_mels, + self.n_formant = n_formants, + self.encoder_only = encoder_only, + self.attentional_mask = attentional_mask, + self.backbone = BackBone(attentional_mask=attentional_mask) + self.position_encoding = build_position_encoding(SeqLen,hidden_dim,'MNI') + self.input_proj = ln.Conv2d(64, hidden_dim, kernel_size=1) + if non_local: + Transformer = TransformerNL + else: + Transformer = TransformerTS + self.transformer = Transformer(d_model=hidden_dim, nhead=n_heads, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=dim_feedforward, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False,encoder_only = 
encoder_only) + self.output_proj = ECoGEncoderFormantHeads(hidden_dim,n_mels,n_formants) + self.query_embed = nn.Embedding(SeqLen, hidden_dim) + + def forward(self,x,mask_prior,mni): + features,mask = self.backbone(x) + pos = self.position_encoding(mni) + hs = self.transformer(self.input_proj(features), mask if self.attentional_mask else None, self.query_embed.weight, pos) + if not self.encoder_only: + hs,encoded = hs + out = self.output_proj(hs) + else: + _,encoded = hs + encoded = encoded.max(-1)[0] + out = self.output_proj(encoded) + return out + + + diff --git a/net_formant_masknormed.py b/net_formant_masknormed.py new file mode 100644 index 00000000..b379f39a --- /dev/null +++ b/net_formant_masknormed.py @@ -0,0 +1,1154 @@ +import os +import pdb +from random import triangular +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F +from torch.nn import Parameter as P +from torch.nn import init +from torch.nn.parameter import Parameter +import numpy as np +import lreq as ln +import math +from registry import * +from transformer_models.position_encoding import build_position_encoding +from transformer_models.transformer import Transformer as TransformerTS +from transformer_models.transformer_nonlocal import Transformer as TransformerNL + +def db(x,noise = -80, slope =35, powerdb=True): + if powerdb: + return ((2*torchaudio.transforms.AmplitudeToDB()(x)).clamp(min=noise)-slope-noise)/slope + else: + return ((torchaudio.transforms.AmplitudeToDB()(x)).clamp(min=noise)-slope-noise)/slope + +# def amplitude(x,noise=-80,slope=35): +# return 10**((x*slope+noise+slope)/20.) + +def amplitude(x,noise_db=-60,max_db=35,trim_noise=False): + if trim_noise: + x_db = (x+1)/2*(max_db-noise_db)+noise_db + if type(x) is np.ndarray: + return 10**(x_db/10)*(np.sign(x_db-noise_db)*0.5+0.5) + else: + return 10**(x_db/10)*((x_db-noise_db).sign()*0.5+0.5) + else: + return 10**(((x+1)/2*(max_db-noise_db)+noise_db)/10) + +def to_db(x,noise_db=-60,max_db=35): + return (torchaudio.transforms.AmplitudeToDB()(x)-noise_db)/(max_db-noise_db)*2-1 + +def wave2spec(wave,n_fft=256,wave_fr=16000,spec_fr=125,noise_db=-60,max_db=22.5,to_db=True,power=2): +# def wave2spec(wave,n_fft=256,wave_fr=16000,spec_fr=125,noise_db=-50,max_db=22.5,to_db=True): + if to_db: + return (torchaudio.transforms.AmplitudeToDB()(torchaudio.transforms.Spectrogram(n_fft*2-1,win_length=n_fft*2-1,hop_length=int(wave_fr/spec_fr),power=power)(wave)).clamp(min=noise_db,max=max_db).transpose(-2,-1)-noise_db)/(max_db-noise_db)*2-1 + else: + return torchaudio.transforms.Spectrogram(n_fft*2-1,win_length=n_fft*2-1,hop_length=int(wave_fr/spec_fr),power=power)(wave).transpose(-2,-1) + + +# def mel_scale(n_mels,hz,min_octave=-31.,max_octave=95.,pt=True): +# def mel_scale(n_mels,hz,min_octave=-58.,max_octave=100.,pt=True): +def mel_scale(n_mels,hz,min_octave=-31.,max_octave=102.,pt=True): + #take absolute hz, return abs mel + # return (torch.log2(hz/440)+31/24)*24*n_mels/126 + if pt: + return (torch.log2(hz/440.)-min_octave/24.)*24*n_mels/(max_octave-min_octave) + else: + return (np.log2(hz/440.)-min_octave/24.)*24*n_mels/(max_octave-min_octave) + +# def inverse_mel_scale(mel,min_octave=-31.,max_octave=95.): +# def inverse_mel_scale(mel,min_octave=-58.,max_octave=100.): +def inverse_mel_scale(mel,min_octave=-31.,max_octave=102.): + #take normalized mel, return absolute hz + # return 440*2**(mel*126/24-31/24) + return 440*2**(mel*(max_octave-min_octave)/24.+min_octave/24.) 
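+
+# mel_scale maps absolute Hz to an absolute position on a 0..n_mels axis, while
+# inverse_mel_scale expects that position normalized to 0..1; the round trip is
+# inverse_mel_scale(mel_scale(n_mels, hz) / n_mels) ~= hz, e.g.
+#   inverse_mel_scale(mel_scale(64, torch.tensor([440.])) / 64.)  # -> ~440 Hz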
+ +# def mel_scale(n_mels,hz,f_min=160.,f_max=8000.,pt=True): +# #take absolute hz, return abs mel +# # return (torch.log2(hz/440)+31/24)*24*n_mels/126 +# m_min = 2595.0 * np.log10(1.0 + (f_min / 700.0)) +# m_max = 2595.0 * np.log10(1.0 + (f_max / 700.0)) +# m_min_ = m_min + (m_max-m_min)/(n_mels+1) +# m_max_ = m_max +# if pt: +# return (2595.0 * torch.log10(1.0 + (hz / 700.0))-m_min_)/(m_max_-m_min_)*n_mels +# else: +# return (2595.0 * np.log10(1.0 + (hz / 700.0))-m_min_)/(m_max_-m_min_)*n_mels + +# def inverse_mel_scale(mel,f_min=160.,f_max=8000.,n_mels=64): +# #take normalized mel, return absolute hz +# # return 440*2**(mel*126/24-31/24) +# m_min = 2595.0 * np.log10(1.0 + (f_min / 700.0)) +# m_max = 2595.0 * np.log10(1.0 + (f_max / 700.0)) +# m_min_ = m_min + (m_max-m_min)/(n_mels+1) +# m_max_ = m_max +# return 700.0 * (10**((mel*(m_max_-m_min_) + m_min_)/ 2595.0) - 1.0) + +def ind2hz(ind,n_fft,max_freq=8000.): + #input abs ind, output abs hz + return ind/(1.0*n_fft)*max_freq + +def hz2ind(hz,n_fft,max_freq=8000.): + # input abs hz, output abs ind + return hz/(1.0*max_freq)*n_fft + +def bandwidth_mel(freqs_hz,bandwidth_hz,n_mels): + # input hz bandwidth, output abs bandwidth on mel + bandwidth_upper = freqs_hz+bandwidth_hz/2. + bandwidth_lower = torch.clamp(freqs_hz-bandwidth_hz/2.,min=1) + bandwidth = mel_scale(n_mels,bandwidth_upper) - mel_scale(n_mels,bandwidth_lower) + return bandwidth + +def torch_P2R(radii, angles): + return radii * torch.cos(angles),radii * torch.sin(angles) +def inverse_spec_to_audio(spec,n_fft = 511,win_length = 511,hop_length = 128,power_synth=True): + ''' + generate random phase, then use istft to inverse spec to audio + ''' + window = torch.hann_window(win_length) + angles = torch.randn_like(spec).uniform_(0, np.pi*2)#torch.zeros_like(spec)#torch.randn_like(spec).uniform_(0, np.pi*2) + spec = spec**0.5 if power_synth else spec + spec_complex = torch.stack(torch_P2R(spec, angles),dim=-1) #real and image in same dim + return torchaudio.functional.istft(spec_complex, n_fft=n_fft, window=window, center=True, win_length=win_length, hop_length=hop_length) + +@GENERATORS.register("GeneratorFormant") +class FormantSysth(nn.Module): + def __init__(self, n_mels=64, k=100, wavebased=False,n_fft=256,noise_db=-50,max_db=22.5,dbbased=False,add_bgnoise=True,log10=False,noise_from_data=False,return_wave=False,power_synth=False): + super(FormantSysth, self).__init__() + self.wave_fr = 16e3 + self.spec_fr = 125 + self.n_fft = n_fft + self.noise_db=noise_db + self.max_db = max_db + self.n_mels = n_mels + self.k = k + self.dbbased=dbbased + self.log10 = log10 + self.add_bgnoise = add_bgnoise + self.wavebased=wavebased + self.noise_from_data = noise_from_data + self.linear_scale = wavebased + self.return_wave = return_wave + self.power_synth = power_synth + self.timbre = Parameter(torch.Tensor(1,1,n_mels)) + self.timbre_mapping = nn.Sequential( + ln.Conv1d(1,128,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,2,1), + # nn.Sigmoid(), + ) + self.bgnoise_mapping = nn.Sequential( + ln.Conv2d(2,2,[1,5],padding=[0,2],gain=1,bias=False), + # nn.Sigmoid(), + ) + self.noise_mapping = nn.Sequential( + ln.Conv2d(2,2,[1,5],padding=[0,2],gain=1,bias=False), + # nn.Sigmoid(), + ) + + self.bgnoise_mapping2 = nn.Sequential( + ln.Conv1d(1,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,gain=1,bias=False), + # nn.Sigmoid(), + ) + self.noise_mapping2 = nn.Sequential( + ln.Conv1d(1,128,1,1), + nn.LeakyReLU(0.2), 
+ ln.Conv1d(128,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,1,gain=1,bias=False), + # nn.Sigmoid(), + ) + self.prior_exp = np.array([0.4963,0.0745,1.9018]) + self.timbre_parameter = Parameter(torch.Tensor(2)) + self.wave_noise_amplifier = Parameter(torch.Tensor(1)) + self.wave_hamon_amplifier = Parameter(torch.Tensor(1)) + + if noise_from_data: + self.bgnoise_amp = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.bgnoise_amp,1) + else: + self.bgnoise_dist = Parameter(torch.Tensor(1,1,1,self.n_fft if self.wavebased else self.n_mels)) + with torch.no_grad(): + nn.init.constant_(self.bgnoise_dist,1.0) + # self.silient = Parameter(torch.Tensor(1,1,n_mels)) + self.silient = -1 + with torch.no_grad(): + nn.init.constant_(self.timbre,1.0) + nn.init.constant_(self.timbre_parameter[0],7) + nn.init.constant_(self.timbre_parameter[1],0.004) + nn.init.constant_(self.wave_noise_amplifier,1) + nn.init.constant_(self.wave_hamon_amplifier,4.) + + # nn.init.constant_(self.silient,-1.0) + +# def formant_mask(self,freq,bandwith,amplitude): +# # freq, bandwith, amplitude: B*formants*time +# freq_cord = torch.arange(self.n_mels) +# time_cord = torch.arange(freq.shape[2]) +# grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) +# grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 +# grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 +# freq = freq.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# bandwith = bandwith.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# amplitude = amplitude.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# # masks = amplitude*torch.exp(-0.693*(grid_freq-freq)**2/(2*(bandwith+0.001)**2)) #B,time,freqchans, formants +# masks = amplitude*torch.exp(-(grid_freq-freq)**2/(2*(bandwith/np.sqrt(2*np.log(2))+0.001)**2)) #B,time,freqchans, formants +# masks = masks.unsqueeze(dim=1) #B,1,time,freqchans, formants +# return masks + + def formant_mask(self,freq_hz,bandwith_hz,amplitude,linear=False, triangle_mask = False,duomask=True, n_formant_noise=1,f0_hz=None,noise=False): + # freq, bandwith, amplitude: B*formants*time + freq_cord = torch.arange(self.n_fft if linear else self.n_mels) + time_cord = torch.arange(freq_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq_hz = ind2hz(grid_freq,self.n_fft,self.wave_fr/2) if linear else inverse_mel_scale(grid_freq/(self.n_mels*1.0)) + freq_hz = freq_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + bandwith_hz = bandwith_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + amplitude = amplitude.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + if self.power_synth: + amplitude = amplitude + alpha = (2*np.sqrt(2*np.log(np.sqrt(2)))) + if not noise: + k = (torch.arange(self.k)+1).reshape([1,self.k,1]) + k_f0 = k*f0_hz #BxkxT + freq_range = (-torch.sign(k_f0-7800)*0.5+0.5) + k_f0 = k_f0.permute([0,2,1]).unsqueeze(-1) #BxTxkx1 + hamonic_dist = (amplitude*(torch.exp(-((k_f0-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))+1E-6).sqrt()).sum(-1).permute([0,2,1]) #BxkxT + norm = (((hamonic_dist*freq_range)**2).sum(1,keepdim=True)+1E-10).sqrt()+1E-10 #Bx1xT + hamonic_dist = (hamonic_dist*freq_range)/norm # sum_k(hamonic_dist**2) = 1, + if self.return_wave: + t = torch.arange(int(f0_hz.shape[2]/self.spec_fr*self.wave_fr))/(1.0*self.wave_fr) #in second + 
hamonic_dist = F.interpolate(hamonic_dist,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr),mode = 'linear',align_corners=False) + # if self.wavebased: + if triangle_mask: + if duomask: + # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-(0.693*(grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]+0.01)**2)) + masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-((grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]/alpha+0.01)**2)) + bw = bandwith_hz[...,-n_formant_noise:] + masks_noise = F.relu(amplitude[...,-n_formant_noise:] * (1 - (1-1/np.sqrt(2))*2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())) + # masks_noise = amplitude[...,-n_formant_noise:] * (1 - 2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz[...,-n_formant_noise:])/(bw+0.01)-0.5)*0.5+0.5) + masks = torch.cat([masks_hamon,masks_noise],dim=-1) + else: + masks = F.relu(amplitude * (1 - (1-1/np.sqrt(2))*2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())) + else: + # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) + if self.power_synth: + masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + else: + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + masks = amplitude*(torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))+1E-6).sqrt() + # masks = amplitude*(torch.exp(-((grid_freq_hz-freq_hz))**2/((2*(bandwith_hz/alpha+0.01)**2)))+1E-10).sqrt() #B,t,freq,formants + if noise: + masks_sum = masks.sum(-1,keepdim=True) + masks = masks/((((masks_sum**2).sum(-2,keepdim=True)/self.n_fft)+1E-10).sqrt()+1E-10) + masks = masks.unsqueeze(dim=1) #B,1,time,freqchans, formants + return masks + else: + masks = masks/norm.squeeze(1).unsqueeze(-1).unsqueeze(-1) + masks = masks.unsqueeze(dim=1) #B,1,time,freqchans, formants + if self.return_wave: + return masks, hamonic_dist#B,1,time,freqchans + else: + return masks + # else: + # if triangle_mask: + # if duomask: + # # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-(0.693*(grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]+0.01)**2)) + # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-((grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]/alpha+0.01)**2)) + # bw = bandwith_hz[...,-n_formant_noise:] + # masks_noise = F.relu(amplitude[...,-n_formant_noise:] * (1 - (1-1/np.sqrt(2))*2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())) + # # masks_noise = amplitude[...,-n_formant_noise:] * (1 - 2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz[...,-n_formant_noise:])/(bw+0.01)-0.5)*0.5+0.5) + # masks = torch.cat([masks_hamon,masks_noise],dim=-1) + # else: + # masks = F.relu(amplitude * (1 - (1-1/np.sqrt(2))*2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())) + # # masks = amplitude * (1 - 2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz)/(bandwith_hz+0.01)-0.5)*0.5+0.5) + # else: + # # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/(2*np.sqrt(2*np.log(2)))+0.01)**2)) #B,time,freqchans, formants 
+ # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) #B,time,freqchans, formants + # masks = masks.unsqueeze(dim=1) #B,1,time,freqchans, formants + # if self.return_wave: + # return masks, hamonic_dist#B,1,time,freqchans + # else: + # return masks + + def voicing_wavebased(self,f0_hz): + #f0: B*1*time, hz + t = torch.arange(int(f0_hz.shape[2]/self.spec_fr*self.wave_fr))/(1.0*self.wave_fr) #in second + t = t.unsqueeze(dim=0).unsqueeze(dim=0) #1, 1, time + k = (torch.arange(self.k)+1).reshape([1,self.k,1]) + f0_hz_interp = F.interpolate(f0_hz,t.shape[-1],mode='linear',align_corners=False) + k_f0 = k*f0_hz_interp + k_f0_sum = 2*np.pi*torch.cumsum(k_f0,-1)/(1.0*self.wave_fr) + wave_k = np.sqrt(2)*torch.sin(k_f0_sum) * (-torch.sign(k_f0-7800)*0.5+0.5) + # wave = 0.12*torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-6000)*0.5+0.5) + # wave = 0.09*torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-self.wave_fr/2)*0.5+0.5) + # wave = 0.09*torch.sigmoid(self.wave_hamon_amplifier) * torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-self.wave_fr/2)*0.5+0.5) + wave = wave_k.sum(dim=1,keepdim=True) + # wave = F.softplus(self.wave_hamon_amplifier) * wave.sum(dim=1,keepdim=True) + spec = wave2spec(wave,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=self.dbbased,power=2. if self.power_synth else 1.) + if self.return_wave: + return spec,wave_k + else: + return spec + + def unvoicing_wavebased(self,f0_hz,bg=False,mapping=True): + # return torch.ones([1,1,f0_hz.shape[2],512]) + # noise = 0.3*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.03*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + if bg: + noise = torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + if mapping: + noise = self.bgnoise_mapping2(noise) + else: + noise = np.sqrt(3.)*(2*torch.rand([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)])-1) + if mapping: + noise = self.noise_mapping2(noise) + # noise = 0.3 * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.3 * F.softplus(self.wave_noise_amplifier) * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + return wave2spec(noise,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=self.dbbased,power=2. if self.power_synth else 1.) 
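FormantSysth follows a source-filter style of synthesis: voicing_wavebased builds the harmonic source as a band-limited sum of sinusoids at integer multiples of the predicted f0 (cumulative phase, harmonics above 7.8 kHz zeroed out), unvoicing_wavebased builds the noise source, and both are rendered to normalized-dB spectrograms via wave2spec before being shaped by the formant masks in forward(). A small self-contained sketch of the harmonic source with a constant 200 Hz pitch (with constant f0 the cumulative phase reduces to 2*pi*k*f0*t; all values are illustrative, not taken from the repo's configs):

import numpy as np
import torch
import torchaudio

wave_fr, spec_fr, n_fft, K = 16000, 125, 256, 100
noise_db, max_db = -60., 22.5
f0 = 200.                                            # constant pitch in Hz (assumption)
t = torch.arange(wave_fr) / wave_fr                  # one second of samples
k_f0 = ((torch.arange(K) + 1.) * f0).reshape(K, 1)   # harmonic frequencies in Hz
keep = (k_f0 < 7800).float()                         # same band limit as voicing_wavebased
wave = (np.sqrt(2) * torch.sin(2 * np.pi * k_f0 * t) * keep).sum(0).reshape(1, 1, -1)

# Same normalization as wave2spec(..., to_db=True, power=2) defined earlier in this file.
spec = torchaudio.transforms.Spectrogram(n_fft * 2 - 1, win_length=n_fft * 2 - 1,
                                         hop_length=int(wave_fr / spec_fr), power=2)(wave)
spec = (torchaudio.transforms.AmplitudeToDB()(spec).clamp(min=noise_db, max=max_db)
        .transpose(-2, -1) - noise_db) / (max_db - noise_db) * 2 - 1
print(spec.shape)   # (1, 1, ~126, 256): time frames x frequency bins, values in [-1, 1]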
+ + # def unvoicing_wavebased(self,f0_hz): + # # return torch.ones([1,1,f0_hz.shape[2],512]) + # # noise = 0.3*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.1*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # # noise = 0.3 * torch.sigmoid(self.wave_noise_amplifier) * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # return wave2spec(noise,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=False) + + def voicing_linear(self,f0_hz,bandwith=2.5): + #f0: B*1*time, hz + freq_cord = torch.arange(self.n_fft) + time_cord = torch.arange(f0_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + f0_hz = f0_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, 1 + f0_hz = f0_hz.repeat([1,1,1,self.k]) #B,time,1, self.k + f0_hz = f0_hz*(torch.arange(self.k)+1).reshape([1,1,1,self.k]) + # bandwith=4 + # bandwith_lower = torch.clamp(f0-bandwith/2,min=1) + # bandwith_upper = f0+bandwith/2 + # bandwith = mel_scale(self.n_mels,bandwith_upper) - mel_scale(self.n_mels,bandwith_lower) + f0 = hz2ind(f0_hz,self.n_fft) + + # hamonics = torch.exp(-(grid_freq-f0)**2/(2*sigma**2)) #gaussian + # hamonics = (1-((grid_freq_hz-f0_hz)/(2*bandwith_hz/2))**2)*(-torch.sign(torch.abs(grid_freq_hz-f0_hz)/(2*bandwith_hz)-0.5)*0.5+0.5) #welch + freq_cord_reshape = freq_cord.reshape([1,1,1,self.n_fft]) + hamonics = (1 - 2/bandwith*(grid_freq-f0).abs())*(-torch.sign(torch.abs(grid_freq-f0)/(bandwith)-0.5)*0.5+0.5) #triangular + # hamonics = (1-((grid_freq-f0)/(bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(2.5*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(2.5*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(3*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(3*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = torch.cos(np.pi*torch.abs(grid_freq-f0)/(4*bandwith))**2*(-torch.sign(torch.abs(grid_freq-f0)/(4*bandwith)-0.5)*0.5+0.5) #hanning + # hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) # B,1,T,F + # condition = (torch.sign(freq_cord_reshape-switch)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-slop*(freq_cord_reshape-switch)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)-torch.abs(self.prior_exp_parameter[2])) * torch.exp(-torch.abs(self.prior_exp_parameter[1])*freq_cord.reshape([1,1,1,self.n_mels])) + torch.abs(self.prior_exp_parameter[2]) # B,1,T,F + + # timbre_parameter = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]).unsqueeze(1) + # condition = (torch.sign(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) * F.softplus(timbre_parameter[...,2:3]) + timbre_parameter[...,3:4] # B,1,T,F + # timbre = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]) + # hamonics = 
(hamonics.sum(dim=-1)*timbre).unsqueeze(dim=1) # B,1,T,F + # hamonics = (hamonics.sum(dim=-1)*self.timbre).unsqueeze(dim=1) # B,1,T,F + + hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) + # hamonics = 180*F.softplus(self.wave_hamon_amplifier)*(hamonics.sum(dim=-1)).unsqueeze(dim=1) + + return hamonics + + def voicing(self,f0_hz): + #f0: B*1*time, hz + freq_cord = torch.arange(self.n_mels) + time_cord = torch.arange(f0_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq_hz = inverse_mel_scale(grid_freq/(self.n_mels*1.0)) + f0_hz = f0_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, 1 + f0_hz = f0_hz.repeat([1,1,1,self.k]) #B,time,1, self.k + f0_hz = f0_hz*(torch.arange(self.k)+1).reshape([1,1,1,self.k]) + if self.log10: + f0_mel = mel_scale(self.n_mels,f0_hz) + band_low_hz = inverse_mel_scale((f0_mel-1)/(self.n_mels*1.0),n_mels = self.n_mels) + band_up_hz = inverse_mel_scale((f0_mel+1)/(self.n_mels*1.0),n_mels = self.n_mels) + bandwith_hz = band_up_hz-band_low_hz + band_low_mel = mel_scale(self.n_mels,band_low_hz) + band_up_mel = mel_scale(self.n_mels,band_up_hz) + bandwith = band_up_mel-band_low_mel + else: + bandwith_hz = 24.7*(f0_hz*4.37/1000+1) + bandwith = bandwidth_mel(f0_hz,bandwith_hz,self.n_mels) + # bandwith_lower = torch.clamp(f0-bandwith/2,min=1) + # bandwith_upper = f0+bandwith/2 + # bandwith = mel_scale(self.n_mels,bandwith_upper) - mel_scale(self.n_mels,bandwith_lower) + f0 = mel_scale(self.n_mels,f0_hz) + switch = mel_scale(self.n_mels,torch.abs(self.timbre_parameter[0])*f0_hz[...,0]).unsqueeze(1) + slop = (torch.abs(self.timbre_parameter[1])*f0_hz[...,0]).unsqueeze(1) + freq_cord_reshape = freq_cord.reshape([1,1,1,self.n_mels]) + if not self.dbbased: + # sigma = bandwith/(np.sqrt(2*np.log(2))); + sigma = bandwith/(2*np.sqrt(2*np.log(2))); + hamonics = torch.exp(-(grid_freq-f0)**2/(2*sigma**2)) #gaussian + # hamonics = (1-((grid_freq_hz-f0_hz)/(2*bandwith_hz/2))**2)*(-torch.sign(torch.abs(grid_freq_hz-f0_hz)/(2*bandwith_hz)-0.5)*0.5+0.5) #welch + else: + # # hamonics = (1-((grid_freq-f0)/(1.75*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(1.75*bandwith)-0.5)*0.5+0.5) #welch + hamonics = (1-((grid_freq-f0)/(2.5*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(2.5*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(3*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(3*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = torch.cos(np.pi*torch.abs(grid_freq-f0)/(4*bandwith))**2*(-torch.sign(torch.abs(grid_freq-f0)/(4*bandwith)-0.5)*0.5+0.5) #hanning + # hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) # B,1,T,F + # condition = (torch.sign(freq_cord_reshape-switch)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-slop*(freq_cord_reshape-switch)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)-torch.abs(self.prior_exp_parameter[2])) * torch.exp(-torch.abs(self.prior_exp_parameter[1])*freq_cord.reshape([1,1,1,self.n_mels])) + torch.abs(self.prior_exp_parameter[2]) # B,1,T,F + + timbre_parameter = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]).unsqueeze(1) + condition = (torch.sign(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*0.5+0.5) + amp = F.softplus(self.wave_hamon_amplifier) if self.dbbased else 180*F.softplus(self.wave_hamon_amplifier) + 
hamonics = amp * ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) * F.softplus(timbre_parameter[...,2:3]) + timbre_parameter[...,3:4] # B,1,T,F + # timbre = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]) + # hamonics = (hamonics.sum(dim=-1)*timbre).unsqueeze(dim=1) # B,1,T,F + # hamonics = (hamonics.sum(dim=-1)*self.timbre).unsqueeze(dim=1) # B,1,T,F + # return F.softplus(self.wave_hamon_amplifier)*hamonics + return hamonics + + def unvoicing(self,f0,bg=False,mapping=True): + # return (0.25*torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels]))+1 + rnd = torch.randn([f0.shape[0],2,f0.shape[2],self.n_fft if self.wavebased else self.n_mels]) + if mapping: + rnd = self.bgnoise_mapping(rnd) if bg else self.noise_mapping(rnd) + real = rnd[:,0:1] + img = rnd[:,1:2] + if self.dbbased: + return (2*torchaudio.transforms.AmplitudeToDB()(torch.sqrt(real**2 + img**2+1E-10))+80).clamp(min=0)/35 + # return (2*torchaudio.transforms.AmplitudeToDB()(F.softplus(self.wave_noise_amplifier)*torch.sqrt(torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2 + torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2))+80).clamp(min=0)/35 + else: + # return torch.ones([f0.shape[0],1,f0.shape[2],self.n_mels]) + return 180*F.softplus(self.wave_noise_amplifier) * torch.sqrt(real**2 + img**2+1E-10) + # return F.softplus(self.wave_noise_amplifier)*torch.sqrt(torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2 + torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2) + + # return (F.softplus(self.wave_noise_amplifier)) * (0.25*torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels]))+1 + # return torch.ones([f0.shape[0],1,f0.shape[2],self.n_mels]) + + def forward(self,components,enable_hamon_excitation=True,enable_noise_excitation=True,enable_bgnoise=True): + # f0: B*1*T, amplitudes: B*2(voicing,unvoicing)*T, freq_formants,bandwidth_formants,amplitude_formants: B*formants*T + amplitudes = components['amplitudes'].unsqueeze(dim=-1) + amplitudes_h = components['amplitudes_h'].unsqueeze(dim=-1) + loudness = components['loudness'].unsqueeze(dim=-1) + f0_hz = components['f0_hz'] + # import pdb;pdb.set_trace() + if self.wavebased: + # self.hamonics = 1800*F.softplus(self.wave_hamon_amplifier)*self.voicing_linear(f0_hz) + # self.noise = 180*self.unvoicing(f0_hz,bg=False,mapping=False) + # self.bgnoise = 18*self.unvoicing(f0_hz,bg=True,mapping=False) + # import pdb;pdb.set_trace() + self.hamonics = self.voicing_wavebased(f0_hz) + self.noise = self.unvoicing_wavebased(f0_hz,bg=False,mapping=False) + self.bgnoise = self.unvoicing_wavebased(f0_hz,bg=True) + else: + self.hamonics = self.voicing(f0_hz) + self.noise = self.unvoicing(f0_hz,bg=False) + self.bgnoise = self.unvoicing(f0_hz,bg=True) + # freq_formants = components['freq_formants']*self.n_mels + # bandwidth_formants = components['bandwidth_formants']*self.n_mels + # excitation = amplitudes[:,0:1]*hamonics + # excitation = loudness*(amplitudes[:,0:1]*hamonics) + + self.excitation_noise = loudness*(amplitudes[:,-1:])*self.noise if self.power_synth else (loudness*amplitudes[:,-1:]+1E-10).sqrt()*self.noise + duomask = 
components['freq_formants_noise_hz'].shape[1]>components['freq_formants_hamon_hz'].shape[1] + n_formant_noise = (components['freq_formants_noise_hz'].shape[1]-components['freq_formants_hamon_hz'].shape[1]) if duomask else components['freq_formants_noise_hz'].shape[1] + self.mask_hamon = self.formant_mask(components['freq_formants_hamon_hz'],components['bandwidth_formants_hamon_hz'],components['amplitude_formants_hamon'],linear = self.linear_scale,f0_hz = f0_hz) + self.mask_noise = self.formant_mask(components['freq_formants_noise_hz'],components['bandwidth_formants_noise_hz'],components['amplitude_formants_noise'],linear = self.linear_scale,triangle_mask=False if self.wavebased else True,duomask=duomask,n_formant_noise=n_formant_noise,f0_hz = f0_hz,noise=True) + # self.mask_hamon = self.formant_mask(components['freq_formants_hamon']*self.n_mels,components['bandwidth_formants_hamon'],components['amplitude_formants_hamon']) + # self.mask_noise = self.formant_mask(components['freq_formants_noise']*self.n_mels,components['bandwidth_formants_noise'],components['amplitude_formants_noise']) + if self.return_wave: + self.hamonics,self.hamonics_wave = self.hamonics + self.mask_hamon, self.hamonic_dist = self.mask_hamon + self.mask_noise = self.mask_noise + if self.power_synth: + self.excitation_hamon_wave = F.interpolate(loudness[...,-1].sqrt()*amplitudes[:,0:1][...,-1],self.hamonics_wave.shape[-1],mode='linear',align_corners=False)*self.hamonics_wave + else: + self.excitation_hamon_wave = F.interpolate((loudness[...,-1]*amplitudes[:,0:1][...,-1]+1E-10).sqrt(),self.hamonics_wave.shape[-1],mode='linear',align_corners=False)*self.hamonics_wave + self.hamonics_wave_ = (self.excitation_hamon_wave*self.hamonic_dist).sum(1,keepdim=True) + # self.hamonics = 40*self.hamonics + self.mask_hamon_sum = 1.414*self.mask_hamon.sum(dim=-1) + # self.mask_hamon_sum = self.mask_hamon.sum(dim=-1) + self.mask_noise_sum = self.mask_noise.sum(dim=-1) + bgdist = F.softplus(self.bgnoise_amp)*self.noise_dist if self.noise_from_data else F.softplus(self.bgnoise_dist) + if self.power_synth: + self.excitation_hamon = loudness*(amplitudes[:,0:1])*self.hamonics + else: + self.excitation_hamon = (loudness*amplitudes[:,0:1]+1E-10).sqrt()*self.hamonics + # import pdb;pdb.set_trace() + self.noise_excitation = self.excitation_noise*self.mask_noise_sum + # import pdb; pdb.set_trace() + if self.return_wave: + self.noise_excitation_wave = 2*inverse_spec_to_audio(self.noise_excitation.squeeze(1).permute(0,2,1),n_fft=self.n_fft*2-1,power_synth=self.power_synth) + self.noise_excitation_wave = F.pad(self.noise_excitation_wave,[0,self.hamonics_wave_.shape[2]-self.noise_excitation_wave.shape[1]]) + self.noise_excitation_wave = self.noise_excitation_wave.unsqueeze(1) + self.rec_wave = self.noise_excitation_wave+self.hamonics_wave_ + if self.wavebased: + # import pdb; pdb.set_trace() + bgn = bgdist*self.bgnoise*0.0003 if (self.add_bgnoise and enable_bgnoise) else 0 + speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.noise_excitation if enable_noise_excitation else 0) + bgn + # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.excitation_noise*self.mask_noise_sum if enable_noise_excitation else 0) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. 
+ self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + # speech = speech if self.power_synth else speech**2 + speech = (torchaudio.transforms.AmplitudeToDB()(speech).clamp(min=self.noise_db)-self.noise_db)/(self.max_db-self.noise_db)*2-1 + else: + # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. + self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + (self.silient*torch.ones(self.mask_hamon_sum.shape) if self.dbbased else 0) + speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.noise_excitation if enable_noise_excitation else 0) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. + self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + (self.silient*torch.ones(self.mask_hamon_sum.shape) if self.dbbased else 0) + # speech = self.excitation_hamon*self.mask_hamon_sum + (self.excitation_noise*self.mask_noise_sum if enable_noise_excitation else 0) + self.silient*torch.ones(self.mask_hamon_sum.shape) + if not self.dbbased: + speech = db(speech) + # import pdb; pdb.set_trace() + if self.return_wave: + return speech,self.rec_wave + else: + return speech + +@ENCODERS.register("EncoderFormant") +class FormantEncoder(nn.Module): + def __init__(self, n_mels=64, n_formants=4,n_formants_noise=2,min_octave=-31,max_octave=96,wavebased=False,n_fft=256,noise_db=-50,max_db=22.5,broud=True,power_synth=False,hop_length=128): + super(FormantEncoder, self).__init__() + self.wavebased = wavebased + self.n_mels = n_mels + self.n_formants = n_formants + self.n_formants_noise = n_formants_noise + self.min_octave = min_octave + self.max_octave = max_octave + self.noise_db = noise_db + self.max_db = max_db + self.broud = broud + self.n_fft = n_fft + self.power_synth=power_synth + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,700.,2700.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + # self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,500.,500.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_bias = Parameter(torch.Tensor(1)) + 
self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_thres = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_bias,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + nn.init.constant_(self.formant_bandwitdh_thres,0) + + # self.formant_freq_limits = torch.cumsum(self.formant_freq_limits_diff,dim=0) + # self.formant_freq_limits_mel = torch.cat([torch.tensor([0.]),mel_scale(n_mels,self.formant_freq_limits)/n_mels]) + # self.formant_freq_limits_mel_diff = torch.reshape(self.formant_freq_limits_mel[1:]-self.formant_freq_limits_mel[:-1],[1,3,1]) + if broud: + if wavebased: + self.conv1_narrow = ln.Conv1d(n_fft,64,3,1,1) + self.conv1_mel = ln.Conv1d(128,64,3,1,1) + self.norm1_mel = nn.GroupNorm(32,64) + self.conv2_mel = ln.Conv1d(64,128,3,1,1) + self.norm2_mel = nn.GroupNorm(32,128) + self.conv_fundementals_mel = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals_mel = nn.GroupNorm(32,128) + self.f0_drop_mel = nn.Dropout() + else: + self.conv1_narrow = ln.Conv1d(n_mels,64,3,1,1) + self.norm1_narrow = nn.GroupNorm(32,64) + self.conv2_narrow = ln.Conv1d(64,128,3,1,1) + self.norm2_narrow = nn.GroupNorm(32,128) + + self.conv_fundementals_narrow = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals_narrow = nn.GroupNorm(32,128) + self.f0_drop_narrow = nn.Dropout() + if wavebased: + self.conv_f0_narrow = ln.Conv1d(256,1,1,1,0) + else: + self.conv_f0_narrow = ln.Conv1d(128,1,1,1,0) + + self.conv_amplitudes_narrow = ln.Conv1d(128,2,1,1,0) + self.conv_amplitudes_h_narrow = ln.Conv1d(128,2,1,1,0) + + if wavebased: + self.conv1 = ln.Conv1d(n_fft,64,3,1,1) + else: + self.conv1 = ln.Conv1d(n_mels,64,3,1,1) + self.norm1 = nn.GroupNorm(32,64) + self.conv2 = ln.Conv1d(64,128,3,1,1) + self.norm2 = nn.GroupNorm(32,128) + + self.conv_fundementals = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,128) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(128,1,1,1,0) + + self.conv_amplitudes = ln.Conv1d(128,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(128,2,1,1,0) + # self.conv_loudness = nn.Sequential(ln.Conv1d(n_fft if wavebased else n_mels,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,1,1,1,0,bias_initial=0.5),) + self.conv_loudness = nn.Sequential(ln.Conv1d(n_fft if wavebased else n_mels,128,1,1,0), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1,0), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,1,0,bias_initial=-9.),) + + if self.broud: + self.conv_formants = ln.Conv1d(128,128,3,1,1) + else: + self.conv_formants = ln.Conv1d(128,128,3,1,1) + self.norm_formants = nn.GroupNorm(32,128) + self.conv_formants_freqs = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(128,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + + self.amplifier = Parameter(torch.Tensor(1)) + self.bias = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.amplifier,1.0) + nn.init.constant_(self.bias,-0.5) + + def forward(self,x,x_denoise=None,duomask=False,noise_level = None,x_amp=None): + x = x.squeeze(dim=1).permute(0,2,1) #B * f * T + if x_denoise is not None: + x_denoise = x_denoise.squeeze(dim=1).permute(0,2,1) + # x_denoise_amp = 
amplitude(x_denoise,self.noise_db,self.max_db) + # import pdb; pdb.set_trace() + if x_amp is None: + x_amp = amplitude(x,self.noise_db,self.max_db,trim_noise=True) + else: + x_amp = x_amp.squeeze(dim=1).permute(0,2,1) + # import pdb; pdb.set_trace() + hann_win = torch.hann_window(5,periodic=False).reshape([1,1,5,1]) + x_smooth = F.conv2d(x.unsqueeze(1).transpose(-2,-1),hann_win,padding=[2,0]).transpose(-2,-1).squeeze(1) + x_amp_smooth = F.conv2d(x_amp.unsqueeze(1).transpose(-2,-1),hann_win,padding=[2,0]).transpose(-2,-1).squeeze(1) + # loudness = F.softplus(self.amplifier)*(torch.mean(x_denoise_amp,dim=1,keepdim=True)) + # loudness = F.relu(F.softplus(self.amplifier)*(torch.mean(x_amp,dim=1,keepdim=True)-noise_level*0.0003)) + # loudness = torch.mean((x*0.5+0.5) if x_denoise is None else (x_denoise*0.5+0.5),dim=1,keepdim=True) + # loudness = F.softplus(self.amplifier)*(loudness) + # loudness = F.softplus(self.amplifier)*torch.mean(x_amp,dim=1,keepdim=True) + # loudness = F.softplus(self.amplifier)*F.relu(loudness - F.softplus(self.bias)) + if self.power_synth: + loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x_smooth)) + else: + # loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x_amp**2)) + loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x_smooth)) + # loudness = F.relu(self.conv_loudness(x)) + + # if not self.power_synth: + # loudness = loudness.sqrt() + + if self.broud: + x_narrow = x + x_narrow = F.leaky_relu(self.norm1_narrow(self.conv1_narrow(x_narrow)),0.2) + x_common_narrow = F.leaky_relu(self.norm2_narrow(self.conv2_narrow(x_narrow)),0.2) + amplitudes = F.softmax(self.conv_amplitudes_narrow(x_common_narrow),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h_narrow(x_common_narrow),dim=1) + x_fundementals_narrow = self.f0_drop_narrow(F.leaky_relu(self.norm_fundementals_narrow(self.conv_fundementals_narrow(x_common_narrow)),0.2)) + + x_amp = amplitude(x.unsqueeze(1),self.noise_db,self.max_db).transpose(-2,-1) + x_mel = to_db(torchaudio.transforms.MelScale(f_max=8000,n_stft=self.n_fft)(x_amp.transpose(-2,-1)),self.noise_db,self.max_db).squeeze(1) + x = F.leaky_relu(self.norm1_mel(self.conv1_mel(x_mel)),0.2) + x_common_mel = F.leaky_relu(self.norm2_mel(self.conv2_mel(x)),0.2) + x_fundementals_mel = self.f0_drop_mel(F.leaky_relu(self.norm_fundementals_mel(self.conv_fundementals_mel(x_common_mel)),0.2)) + + f0_hz = torch.sigmoid(self.conv_f0_narrow(torch.cat([x_fundementals_narrow,x_fundementals_mel],dim=1))) * 120 + 180 # 180hz < f0 < 300 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + hann_win = torch.hann_window(21,periodic=False).reshape([1,1,21,1]) + x = to_db(F.conv2d(x_amp,hann_win,padding=[10,0]).transpose(-2,-1),self.noise_db,self.max_db).squeeze(1) + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x_common = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + + + + else: + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x_common = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + + + # loudness = F.relu(self.conv_loudness(x_common)) + # loudness = F.relu(self.conv_loudness(x_common)) +(10**(self.noise_db/10.-1) if self.wavebased else 0) + amplitudes = F.softmax(self.conv_amplitudes(x_common),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + + # x_fundementals = F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) 
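+        # All pitch/formant frequency predictions below share one parameterization:
+        # a conv output is squashed by a sigmoid and mapped affinely into a fixed
+        # physical range (f0 into 88-420 Hz; each formant frequency into the
+        # [low, high] Hz window given by formant_freq_limits_abs/_low in __init__),
+        # then re-expressed on the normalized mel axis as mel_scale(n_mels, hz)/n_mels.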
+ # f0 in mel: + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) + # f0 = F.tanh(self.conv_f0(x_fundementals)) * (16/64)*(self.n_mels/64) # 72hz < f0 < 446 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + + # f0 in hz: + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 120 + 180 # 180hz < f0 < 300 hz + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 120 + 180 # 180hz < f0 < 300 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 528 + 88 # 88hz < f0 < 616 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 302 + 118 # 118hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 240 + 180 # 180hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 260 + 160 # 160hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + # # relative freq: + # formants_freqs_hz = formants_freqs*(self.formant_freq_limits_diff[:,:self.n_formants]-self.formant_freq_limits_diff_low[:,:self.n_formants])+self.formant_freq_limits_diff_low[:,:self.n_formants] + # # formants_freqs_hz = formants_freqs*6839 + # formants_freqs_hz = torch.cumsum(formants_freqs_hz,dim=1) + + # abs freq: + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + # formants_freqs_hz = formants_freqs*6839 + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) *6839 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 150 + # formants_bandwidth_hz = 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+100) + # formants_bandwidth_hz = (torch.sigmoid(self.conv_formants_bandwidth(x_formants))) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100) #good for spec based method + # formants_bandwidth_hz = ((3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.0125*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + # formants_bandwidth_hz = (2**(torch.tanh(self.formant_bandwitdh_slop))*0.001*torch.relu(formants_freqs_hz-4000*torch.sigmoid(self.formant_bandwitdh_thres))+375*2**(torch.tanh(self.formant_bandwitdh_bias))) + # formants_bandwidth_hz = torch.exp(0.4*torch.tanh(self.conv_formants_bandwidth(x_formants))) * (0.00625*torch.relu(formants_freqs_hz-0)+375) + # formants_bandwidth_hz = (100*(torch.tanh(self.conv_formants_bandwidth(x_formants))) + (0.035*torch.relu(formants_freqs_hz-950)+250)) if self.wavebased else ((3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.0125*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + # formants_bandwidth_hz = 
(100*(torch.tanh(self.conv_formants_bandwidth(x_formants))) + (0.035*torch.relu(formants_freqs_hz-950)+250)) if self.wavebased else ((torch.sigmoid(self.conv_formants_bandwidth(x_formants))+0.2) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * (0.075*3*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+(2*torch.sigmoid(self.formant_bandwitdh_ratio)+1)*50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+50) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + # formants_bandwidth_upper = formants_freqs_hz+formants_bandwidth_hz/2 + # formants_bandwidth_lower = torch.clamp(formants_freqs_hz-formants_bandwidth_hz/2,min=1) + # formants_bandwidth = (mel_scale(self.n_mels,formants_bandwidth_upper) - mel_scale(self.n_mels,formants_bandwidth_lower))/(self.n_mels*1.0) + # formants_amplitude = F.softmax(torch.cumsum(-F.relu(self.conv_formants_amplitude(x_formants)),dim=1),dim=1) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + # # relative freq: + # formants_freqs_hz = formants_freqs*(self.formant_freq_limits_diff[:,:self.n_formants]-self.formant_freq_limits_diff_low[:,:self.n_formants])+self.formant_freq_limits_diff_low[:,:self.n_formants] + # # formants_freqs_hz = formants_freqs*6839 + # formants_freqs_hz = torch.cumsum(formants_freqs_hz,dim=1) + + # abs freq: + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + if duomask: + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + # formants_freqs_hz = formants_freqs*6839 + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) *6839 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 150 + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 8000 + 2000 + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + # formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 8000 + 2000 #2000-10000 + # formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 2000 #0-2000 + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + if duomask: + formants_bandwidth_hz_noise = 
torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 4000 + 1000 + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 4000 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * (0.075*3*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+(2*torch.sigmoid(self.formant_bandwitdh_ratio)+1)*50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+50) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + # formants_bandwidth_upper = formants_freqs_hz+formants_bandwidth_hz/2 + # formants_bandwidth_lower = torch.clamp(formants_freqs_hz-formants_bandwidth_hz/2,min=1) + # formants_bandwidth = (mel_scale(self.n_mels,formants_bandwidth_upper) - mel_scale(self.n_mels,formants_bandwidth_lower))/(self.n_mels*1.0) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + if duomask: + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + +class FromECoG(nn.Module): + def __init__(self, outputs,residual=False,shape='3D'): + super().__init__() + self.residual=residual + if shape =='3D': + self.from_ecog = ln.Conv3d(1, outputs, [9,1,1], 1, [4,0,0]) + else: + self.from_ecog = ln.Conv2d(1, outputs, [9,1], 1, [4,0]) + + def forward(self, x): + x = self.from_ecog(x) + if not self.residual: + x = F.leaky_relu(x, 0.2) + return x + +class ECoGMappingBlock(nn.Module): + def __init__(self, inputs, outputs, kernel_size,dilation=1,fused_scale=True,residual=False,resample=[],pool=None,shape='3D'): + super(ECoGMappingBlock, self).__init__() + self.residual = residual + self.pool = pool + self.inputs_resample = resample + self.dim_missmatch = (inputs!=outputs) + self.resample = resample + if not self.resample: + self.resample=1 + self.padding = list(np.array(dilation)*(np.array(kernel_size)-1)//2) + if shape=='2D': + conv=ln.Conv2d + maxpool = nn.MaxPool2d + avgpool = nn.AvgPool2d + if shape=='3D': + conv=ln.Conv3d + maxpool = nn.MaxPool3d + avgpool = nn.AvgPool3d + # self.padding = [dilation[i]*(kernel_size[i]-1)//2 for i in range(len(dilation))] + if residual: + self.norm1 = nn.GroupNorm(min(inputs,32),inputs) + else: + self.norm1 = nn.GroupNorm(min(outputs,32),outputs) + if pool is None: + self.conv1 = 
conv(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False) + else: + self.conv1 = conv(inputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.pool1 = maxpool(self.resample,self.resample) if self.pool=='Max' else avgpool(self.resample,self.resample) + if self.inputs_resample or self.dim_missmatch: + if pool is None: + self.convskip = conv(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False) + else: + self.convskip = conv(inputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.poolskip = maxpool(self.resample,self.resample) if self.pool=='Max' else avgpool(self.resample,self.resample) + + self.conv2 = conv(outputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.norm2 = nn.GroupNorm(min(outputs,32),outputs) + + def forward(self,x): + if self.residual: + x = F.leaky_relu(self.norm1(x),0.2) + if self.inputs_resample or self.dim_missmatch: + # x_skip = F.avg_pool3d(x,self.resample,self.resample) + x_skip = self.convskip(x) + if self.pool is not None: + x_skip = self.poolskip(x_skip) + else: + x_skip = x + x = F.leaky_relu(self.norm2(self.conv1(x)),0.2) + if self.pool is not None: + x = self.poolskip(x) + x = self.conv2(x) + x = x_skip + x + else: + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + return x + + + +@ECOG_ENCODER.register("ECoGMappingBottleneck") +class ECoGMapping_Bottleneck(nn.Module): + def __init__(self,n_mels,n_formants,n_formants_noise=1): + super(ECoGMapping_Bottleneck, self).__init__() + self.n_formants = n_formants + self.n_mels = n_mels + self.n_formants_noise = n_formants_noise + + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_ratio = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_ratio,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + + + self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + self.conv3 = 
ECoGMappingBlock(64,128,[3,3,3],residual=True,resample = [2,2,2],pool='MAX') + self.conv4 = ECoGMappingBlock(128,256,[3,3,3],residual=True,resample = [2,2,2],pool='MAX') + self.norm = nn.GroupNorm(32,256) + self.conv5 = ln.Conv1d(256,256,3,1,1) + self.norm2 = nn.GroupNorm(32,256) + self.conv6 = ln.ConvTranspose1d(256, 128, 3, 2, 1, transform_kernel=True) + self.norm3 = nn.GroupNorm(32,128) + self.conv7 = ln.ConvTranspose1d(128, 64, 3, 2, 1, transform_kernel=True) + self.norm4 = nn.GroupNorm(32,64) + self.conv8 = ln.ConvTranspose1d(64, 32, 3, 2, 1, transform_kernel=True) + self.norm5 = nn.GroupNorm(32,32) + self.conv9 = ln.ConvTranspose1d(32, 32, 3, 2, 1, transform_kernel=True) + self.norm6 = nn.GroupNorm(32,32) + + self.conv_fundementals = ln.Conv1d(32,32,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,32) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(32,1,1,1,0) + self.conv_amplitudes = ln.Conv1d(32,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(32,2,1,1,0) + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) + + self.conv_formants = ln.Conv1d(32,32,3,1,1) + self.norm_formants = nn.GroupNorm(32,32) + self.conv_formants_freqs = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(32,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + + + def forward(self,ecog,mask_prior,mni): + x_common_all = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,4:] + if mask_prior is not None: + mask = mask*mask_prior_d + x = x[:,:,4:] + x = x*mask + x = self.conv3(x) + x = self.conv4(x) + x = x.max(-1)[0].max(-1)[0] + x = self.conv5(F.leaky_relu(self.norm(x),0.2)) + x = self.conv6(F.leaky_relu(self.norm2(x),0.2)) + x = self.conv7(F.leaky_relu(self.norm3(x),0.2)) + x = self.conv8(F.leaky_relu(self.norm4(x),0.2)) + x = self.conv9(F.leaky_relu(self.norm5(x),0.2)) + x_common = F.leaky_relu(self.norm6(x),0.2) + x_common_all += [x_common] + + x_common = torch.cat(x_common_all,dim=0) + loudness = F.softplus(self.conv_loudness(x_common)) + amplitudes = F.softmax(self.conv_amplitudes(x_common),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + + # x_fundementals = F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) + # f0 = F.tanh(self.conv_f0(x_fundementals)) * (16/64)*(self.n_mels/64) # 72hz < f0 < 446 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + + # f0 in hz: + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * 528 + 88 # 88hz < f0 < 616 hz + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + f0 = 
torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + # formants_freqs = torch.cumsum(formants_freqs,dim=1) + # formants_freqs = formants_freqs + + # abs freq + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) + # formants_bandwidth_hz = (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100) + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + # formants_bandwidth_hz = (torch.sigmoid(self.conv_formants_bandwidth(x_formants))) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + # formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:1]-self.formant_freq_limits_abs_noise_low[:,:1])+self.formant_freq_limits_abs_noise_low[:,:1] + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + # formants_bandwidth_hz_noise = F.relu(self.conv_formants_bandwidth_noise(x_formants)) * 8000 + 2000 + # formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + # formants_amplitude_noise = F.softmax(self.conv_formants_amplitude_noise(x_formants),dim=1) + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 
'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + + +class BackBone(nn.Module): + def __init__(self,attentional_mask=True): + super(BackBone, self).__init__() + self.attentional_mask = attentional_mask + self.from_ecog = FromECoG(16,residual=True,shape='2D') + self.conv1 = ECoGMappingBlock(16,32,[5,1],residual=True,resample = [1,1],shape='2D') + self.conv2 = ECoGMappingBlock(32,64,[3,1],residual=True,resample = [1,1],shape='2D') + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv2d(64,1,[3,1],1,[1,0]) + + def forward(self,ecog): + x_common_all = [] + mask_all=[] + for d in range(len(ecog)): + x = ecog[d] + x = x.unsqueeze(1) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + if self.attentional_mask: + mask = F.relu(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,16:] + x = x[:,:,16:] + mask_all +=[mask] + else: + # mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + # mask = mask[:,:,16:] + x = x[:,:,16:] + # x = x*mask + + x_common_all +=[x] + + x_common = torch.cat(x_common_all,dim=0) + if self.attentional_mask: + mask = torch.cat(mask_all,dim=0) + return x_common,mask.squeeze(1) if self.attentional_mask else None + +class ECoGEncoderFormantHeads(nn.Module): + def __init__(self,inputs,n_mels,n_formants): + super(ECoGEncoderFormantHeads,self).__init__() + self.n_mels = n_mels + self.f0 = ln.Conv1d(inputs,1,1) + self.loudness = ln.Conv1d(inputs,1,1) + self.amplitudes = ln.Conv1d(inputs,2,1) + self.freq_formants = ln.Conv1d(inputs,n_formants,1) + self.bandwidth_formants = ln.Conv1d(inputs,n_formants,1) + self.amplitude_formants = ln.Conv1d(inputs,n_formants,1) + + def forward(self,x): + loudness = F.relu(self.loudness(x)) + f0 = torch.sigmoid(self.f0(x)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + amplitudes = F.softmax(self.amplitudes(x),dim=1) + freq_formants = torch.sigmoid(self.freq_formants(x)) + freq_formants = torch.cumsum(freq_formants,dim=1) + bandwidth_formants = torch.sigmoid(self.bandwidth_formants(x)) + amplitude_formants = F.softmax(self.amplitude_formants(x),dim=1) + return {'f0':f0, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'freq_formants':freq_formants, + 'bandwidth_formants':bandwidth_formants, + 'amplitude_formants':amplitude_formants,} + +@ECOG_ENCODER.register("ECoGMappingTransformer") +class ECoGMapping_Transformer(nn.Module): + def __init__(self,n_mels,n_formants,SeqLen=128,hidden_dim=256,dim_feedforward=256,encoder_only=False,attentional_mask=False,n_heads=1,non_local=False): + super(ECoGMapping_Transformer, self).__init__() + self.n_mels = n_mels, + self.n_formant = n_formants, + self.encoder_only = encoder_only, + self.attentional_mask = attentional_mask, + self.backbone = BackBone(attentional_mask=attentional_mask) + self.position_encoding = build_position_encoding(SeqLen,hidden_dim,'MNI') + self.input_proj = ln.Conv2d(64, hidden_dim, kernel_size=1) + if non_local: + 
Transformer = TransformerNL + else: + Transformer = TransformerTS + self.transformer = Transformer(d_model=hidden_dim, nhead=n_heads, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=dim_feedforward, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False,encoder_only = encoder_only) + self.output_proj = ECoGEncoderFormantHeads(hidden_dim,n_mels,n_formants) + self.query_embed = nn.Embedding(SeqLen, hidden_dim) + + def forward(self,x,mask_prior,mni): + features,mask = self.backbone(x) + pos = self.position_encoding(mni) + hs = self.transformer(self.input_proj(features), mask if self.attentional_mask else None, self.query_embed.weight, pos) + if not self.encoder_only: + hs,encoded = hs + out = self.output_proj(hs) + else: + _,encoded = hs + encoded = encoded.max(-1)[0] + out = self.output_proj(encoded) + return out + + + diff --git a/net_formant_wave2specbased.py b/net_formant_wave2specbased.py new file mode 100644 index 00000000..9e82750e --- /dev/null +++ b/net_formant_wave2specbased.py @@ -0,0 +1,1161 @@ +import os +import pdb +from random import triangular +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F +from torch.nn import Parameter as P +from torch.nn import init +from torch.nn.parameter import Parameter +import numpy as np +import lreq as ln +import math +from registry import * +from transformer_models.position_encoding import build_position_encoding +from transformer_models.transformer import Transformer as TransformerTS +from transformer_models.transformer_nonlocal import Transformer as TransformerNL + +def db(x,noise = -80, slope =35, powerdb=True): + if powerdb: + return ((2*torchaudio.transforms.AmplitudeToDB()(x)).clamp(min=noise)-slope-noise)/slope + else: + return ((torchaudio.transforms.AmplitudeToDB()(x)).clamp(min=noise)-slope-noise)/slope + +# def amplitude(x,noise=-80,slope=35): +# return 10**((x*slope+noise+slope)/20.) 
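+# Normalization convention assumed by the helpers below (a hedged reading of this file,
+# not an authoritative spec): spectrograms are stored scaled to [-1, 1], where -1
+# corresponds to `noise_db` and +1 to `max_db` (both in dB). `amplitude` undoes that
+# scaling back to power, e.g. with noise_db=-60 and max_db=35 a normalized value of 0
+# maps to (0+1)/2*(35-(-60))-60 = -12.5 dB, i.e. 10**(-12.5/10) ~= 0.056. `to_db` is the
+# forward mapping, and `wave2spec` applies the same scaling after clamping the dB
+# spectrogram to [noise_db, max_db].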
+ +def amplitude(x,noise_db=-60,max_db=35,trim_noise=False): + if trim_noise: + x_db = (x+1)/2*(max_db-noise_db)+noise_db + if type(x) is np.ndarray: + return 10**(x_db/10)*(np.sign(x_db-noise_db)*0.5+0.5) + else: + return 10**(x_db/10)*((x_db-noise_db).sign()*0.5+0.5) + else: + return 10**(((x+1)/2*(max_db-noise_db)+noise_db)/10) + +def to_db(x,noise_db=-60,max_db=35): + return (torchaudio.transforms.AmplitudeToDB()(x)-noise_db)/(max_db-noise_db)*2-1 + +def wave2spec(wave,n_fft=256,wave_fr=16000,spec_fr=125,noise_db=-60,max_db=22.5,to_db=True,power=2): +# def wave2spec(wave,n_fft=256,wave_fr=16000,spec_fr=125,noise_db=-50,max_db=22.5,to_db=True): + if to_db: + return (torchaudio.transforms.AmplitudeToDB()(torchaudio.transforms.Spectrogram(n_fft*2-1,win_length=n_fft*2-1,hop_length=int(wave_fr/spec_fr),power=power)(wave)).clamp(min=noise_db,max=max_db).transpose(-2,-1)-noise_db)/(max_db-noise_db)*2-1 + else: + return torchaudio.transforms.Spectrogram(n_fft*2-1,win_length=n_fft*2-1,hop_length=int(wave_fr/spec_fr),power=power)(wave).transpose(-2,-1) + + +# def mel_scale(n_mels,hz,min_octave=-31.,max_octave=95.,pt=True): +# def mel_scale(n_mels,hz,min_octave=-58.,max_octave=100.,pt=True): +def mel_scale(n_mels,hz,min_octave=-31.,max_octave=102.,pt=True): + #take absolute hz, return abs mel + # return (torch.log2(hz/440)+31/24)*24*n_mels/126 + if pt: + return (torch.log2(hz/440.)-min_octave/24.)*24*n_mels/(max_octave-min_octave) + else: + return (np.log2(hz/440.)-min_octave/24.)*24*n_mels/(max_octave-min_octave) + +# def inverse_mel_scale(mel,min_octave=-31.,max_octave=95.): +# def inverse_mel_scale(mel,min_octave=-58.,max_octave=100.): +def inverse_mel_scale(mel,min_octave=-31.,max_octave=102.): + #take normalized mel, return absolute hz + # return 440*2**(mel*126/24-31/24) + return 440*2**(mel*(max_octave-min_octave)/24.+min_octave/24.) + +# def mel_scale(n_mels,hz,f_min=160.,f_max=8000.,pt=True): +# #take absolute hz, return abs mel +# # return (torch.log2(hz/440)+31/24)*24*n_mels/126 +# m_min = 2595.0 * np.log10(1.0 + (f_min / 700.0)) +# m_max = 2595.0 * np.log10(1.0 + (f_max / 700.0)) +# m_min_ = m_min + (m_max-m_min)/(n_mels+1) +# m_max_ = m_max +# if pt: +# return (2595.0 * torch.log10(1.0 + (hz / 700.0))-m_min_)/(m_max_-m_min_)*n_mels +# else: +# return (2595.0 * np.log10(1.0 + (hz / 700.0))-m_min_)/(m_max_-m_min_)*n_mels + +# def inverse_mel_scale(mel,f_min=160.,f_max=8000.,n_mels=64): +# #take normalized mel, return absolute hz +# # return 440*2**(mel*126/24-31/24) +# m_min = 2595.0 * np.log10(1.0 + (f_min / 700.0)) +# m_max = 2595.0 * np.log10(1.0 + (f_max / 700.0)) +# m_min_ = m_min + (m_max-m_min)/(n_mels+1) +# m_max_ = m_max +# return 700.0 * (10**((mel*(m_max_-m_min_) + m_min_)/ 2595.0) - 1.0) + +def ind2hz(ind,n_fft,max_freq=8000.): + #input abs ind, output abs hz + return ind/(1.0*n_fft)*max_freq + +def hz2ind(hz,n_fft,max_freq=8000.): + # input abs hz, output abs ind + return hz/(1.0*max_freq)*n_fft + +def bandwidth_mel(freqs_hz,bandwidth_hz,n_mels): + # input hz bandwidth, output abs bandwidth on mel + bandwidth_upper = freqs_hz+bandwidth_hz/2. 
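+    # the lower band edge is clamped to >= 1 Hz below, presumably so that mel_scale
+    # (a log2 of hz/440.) stays finite when freqs_hz - bandwidth_hz/2 would be <= 0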
+ bandwidth_lower = torch.clamp(freqs_hz-bandwidth_hz/2.,min=1) + bandwidth = mel_scale(n_mels,bandwidth_upper) - mel_scale(n_mels,bandwidth_lower) + return bandwidth + +def torch_P2R(radii, angles): + return radii * torch.cos(angles),radii * torch.sin(angles) +def inverse_spec_to_audio(spec,n_fft = 511,win_length = 511,hop_length = 128,power_synth=True): + ''' + generate random phase, then use istft to inverse spec to audio + ''' + window = torch.hann_window(win_length) + angles = torch.randn_like(spec).uniform_(0, np.pi*2)#torch.zeros_like(spec)#torch.randn_like(spec).uniform_(0, np.pi*2) + spec = spec**0.5 if power_synth else spec + spec_complex = torch.stack(torch_P2R(spec, angles),dim=-1) #real and image in same dim + return torchaudio.functional.istft(spec_complex, n_fft=n_fft, window=window, center=True, win_length=win_length, hop_length=hop_length) + +@GENERATORS.register("GeneratorFormant") +class FormantSysth(nn.Module): + def __init__(self, n_mels=64, k=100, wavebased=False,n_fft=256,noise_db=-50,max_db=22.5,dbbased=False,add_bgnoise=True,log10=False,noise_from_data=False,return_wave=False,power_synth=False): + super(FormantSysth, self).__init__() + self.wave_fr = 16e3 + self.spec_fr = 125 + self.n_fft = n_fft + self.noise_db=noise_db + self.max_db = max_db + self.n_mels = n_mels + self.k = k + self.dbbased=dbbased + self.log10 = log10 + self.add_bgnoise = add_bgnoise + self.wavebased=wavebased + self.noise_from_data = noise_from_data + self.linear_scale = wavebased + self.return_wave = return_wave + self.power_synth = power_synth + self.timbre = Parameter(torch.Tensor(1,1,n_mels)) + self.timbre_mapping = nn.Sequential( + ln.Conv1d(1,128,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,2,1), + # nn.Sigmoid(), + ) + self.bgnoise_mapping = nn.Sequential( + ln.Conv2d(2,2,[1,5],padding=[0,2],gain=1,bias=False), + # nn.Sigmoid(), + ) + self.noise_mapping = nn.Sequential( + ln.Conv2d(2,2,[1,5],padding=[0,2],gain=1,bias=False), + # nn.Sigmoid(), + ) + + self.bgnoise_mapping2 = nn.Sequential( + ln.Conv1d(1,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,gain=1,bias=False), + # nn.Sigmoid(), + ) + self.noise_mapping2 = nn.Sequential( + ln.Conv1d(1,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,1,gain=1,bias=False), + # nn.Sigmoid(), + ) + self.prior_exp = np.array([0.4963,0.0745,1.9018]) + self.timbre_parameter = Parameter(torch.Tensor(2)) + self.wave_noise_amplifier = Parameter(torch.Tensor(1)) + self.wave_hamon_amplifier = Parameter(torch.Tensor(1)) + + if noise_from_data: + self.bgnoise_amp = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.bgnoise_amp,1) + else: + self.bgnoise_dist = Parameter(torch.Tensor(1,1,1,self.n_fft if self.wavebased else self.n_mels)) + with torch.no_grad(): + nn.init.constant_(self.bgnoise_dist,1.0) + # self.silient = Parameter(torch.Tensor(1,1,n_mels)) + self.silient = -1 + with torch.no_grad(): + nn.init.constant_(self.timbre,1.0) + nn.init.constant_(self.timbre_parameter[0],7) + nn.init.constant_(self.timbre_parameter[1],0.004) + nn.init.constant_(self.wave_noise_amplifier,1) + nn.init.constant_(self.wave_hamon_amplifier,4.) 
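+    # A sketch of the Gaussian convention used by formant_mask below (my reading of the
+    # code): each formant contributes amplitude * exp(-(f - f_formant)^2 / (2*sigma^2))
+    # with sigma = bandwidth_hz/alpha and alpha = 2*sqrt(2*ln(sqrt(2))) = 2*sqrt(ln 2),
+    # so the squared Gaussian falls to half its peak at f_formant +/- bandwidth_hz/2,
+    # i.e. bandwidth_hz acts as a half-power bandwidth. The harmonic branch evaluates the
+    # mask at the harmonics k*f0 and renormalizes across harmonics to unit energy per
+    # frame; the noise branch evaluates it on the linear frequency grid of the
+    # spectrogram and normalizes its energy across frequency per frame.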
+ + # nn.init.constant_(self.silient,-1.0) + +# def formant_mask(self,freq,bandwith,amplitude): +# # freq, bandwith, amplitude: B*formants*time +# freq_cord = torch.arange(self.n_mels) +# time_cord = torch.arange(freq.shape[2]) +# grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) +# grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 +# grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 +# freq = freq.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# bandwith = bandwith.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# amplitude = amplitude.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants +# # masks = amplitude*torch.exp(-0.693*(grid_freq-freq)**2/(2*(bandwith+0.001)**2)) #B,time,freqchans, formants +# masks = amplitude*torch.exp(-(grid_freq-freq)**2/(2*(bandwith/np.sqrt(2*np.log(2))+0.001)**2)) #B,time,freqchans, formants +# masks = masks.unsqueeze(dim=1) #B,1,time,freqchans, formants +# return masks + + def formant_mask(self,freq_hz,bandwith_hz,amplitude,linear=False, triangle_mask = False,duomask=True, n_formant_noise=1,f0_hz=None,noise=False): + # freq, bandwith, amplitude: B*formants*time + freq_cord = torch.arange(self.n_fft if linear else self.n_mels) + time_cord = torch.arange(freq_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq_hz = ind2hz(grid_freq,self.n_fft,self.wave_fr/2) if linear else inverse_mel_scale(grid_freq/(self.n_mels*1.0)) + freq_hz = freq_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + bandwith_hz = bandwith_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + amplitude = amplitude.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, formants + if self.power_synth: + amplitude = amplitude + alpha = (2*np.sqrt(2*np.log(np.sqrt(2)))) + if not noise: + # t = torch.arange(int(f0_hz.shape[2]/self.spec_fr*self.wave_fr))/(1.0*self.wave_fr) #in second + # t = t.unsqueeze(dim=0).unsqueeze(dim=0) #1, 1, time + k = (torch.arange(self.k)+1).reshape([1,self.k,1]) + # f0_hz_interp = F.interpolate(f0_hz,t.shape[-1],mode='linear',align_corners=False) #Bx1xT + # bandwith_hz_interp = F.interpolate(bandwith_hz.permute(0,2,3,1),[bandwith_hz.shape[-1],t.shape[-1]],mode='bilinear',align_corners=False).permute(0,3,1,2) #Bx1xT + # freq_hz_interp = F.interpolate(freq_hz.permute(0,2,3,1),[freq_hz.shape[-1],t.shape[-1]],mode='bilinear',align_corners=False).permute(0,3,1,2) #Bx1xT + k_f0 = k*f0_hz #BxkxT + freq_range = (-torch.sign(k_f0-7800)*0.5+0.5) #BxkxT + k_f0 = k_f0.permute([0,2,1]).unsqueeze(-1) #BxTxkx1 + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) if self.wavebased else amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) #B,time,freqchans, formants + # amplitude_interp = F.interpolate(amplitude.permute(0,2,3,1),[amplitude.shape[-1],t.shape[-1]],mode='bilinear',align_corners=False).permute(0,3,1,2) #Bx1xT + hamonic_dist = (amplitude*(torch.exp(-((k_f0-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))+1E-6).sqrt()).sum(-1).permute([0,2,1]) #BxkxT + hamonic_dist = (hamonic_dist*freq_range)/((((hamonic_dist*freq_range)**2).sum(1,keepdim=True)+1E-10).sqrt()+1E-10) # sum_k(hamonic_dist**2) = 1 + hamonic_dist = F.interpolate(hamonic_dist,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr),mode = 'linear',align_corners=False) + return hamonic_dist # B,k,T + 
else: + masks = amplitude*(torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))+1E-6).sqrt() #B,time,freqchans, formants + masks = masks.sum(-1) #B,time,freqchans + masks = masks/((((masks**2).sum(-1,keepdim=True)/self.n_fft)+1E-10).sqrt()+1E-10) + masks = masks.unsqueeze(dim=1) #B,1,time,freqchans + return masks #B,1,time,freqchans + + # if self.wavebased: + # if triangle_mask: + # if duomask: + # # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-(0.693*(grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]+0.01)**2)) + # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-((grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]/alpha+0.01)**2)) + # bw = bandwith_hz[...,-n_formant_noise:] + # masks_noise = F.relu(amplitude[...,-n_formant_noise:] * (1 - (1-1/np.sqrt(2))*2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())) + # # masks_noise = amplitude[...,-n_formant_noise:] * (1 - 2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz[...,-n_formant_noise:])/(bw+0.01)-0.5)*0.5+0.5) + # masks = torch.cat([masks_hamon,masks_noise],dim=-1) + # else: + # masks = F.relu(amplitude * (1 - (1-1/np.sqrt(2))*2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())) + # else: + # # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) + # if self.power_synth: + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + # else: + # # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + # masks = amplitude*(torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2))+1E-6).sqrt() + # else: + # if triangle_mask: + # if duomask: + # # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-(0.693*(grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]+0.01)**2)) + # masks_hamon = amplitude[...,:-n_formant_noise]*torch.exp(-((grid_freq_hz-freq_hz[...,:-n_formant_noise]))**2/(2*(bandwith_hz[...,:-n_formant_noise]/alpha+0.01)**2)) + # bw = bandwith_hz[...,-n_formant_noise:] + # masks_noise = F.relu(amplitude[...,-n_formant_noise:] * (1 - (1-1/np.sqrt(2))*2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())) + # # masks_noise = amplitude[...,-n_formant_noise:] * (1 - 2/(bw+0.01)*(grid_freq_hz-freq_hz[...,-n_formant_noise:]).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz[...,-n_formant_noise:])/(bw+0.01)-0.5)*0.5+0.5) + # masks = torch.cat([masks_hamon,masks_noise],dim=-1) + # else: + # masks = F.relu(amplitude * (1 - (1-1/np.sqrt(2))*2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())) + # # masks = amplitude * (1 - 2/(bandwith_hz+0.01)*(grid_freq_hz-freq_hz).abs())*(-torch.sign(torch.abs(grid_freq_hz-freq_hz)/(bandwith_hz+0.01)-0.5)*0.5+0.5) + # else: + # # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/alpha+0.01)**2)) + # masks = amplitude*torch.exp(-((grid_freq_hz-freq_hz))**2/(2*(bandwith_hz/(2*np.sqrt(2*np.log(2)))+0.01)**2)) #B,time,freqchans, formants + # masks = amplitude*torch.exp(-(0.693*(grid_freq_hz-freq_hz))**2/(2*(bandwith_hz+0.01)**2)) #B,time,freqchans, formants + + def voicing_wavebased(self,f0_hz): + #f0: B*1*time, hz + t = torch.arange(int(f0_hz.shape[2]/self.spec_fr*self.wave_fr))/(1.0*self.wave_fr) #in second + t = 
t.unsqueeze(dim=0).unsqueeze(dim=0) #1, 1, time + k = (torch.arange(self.k)+1).reshape([1,self.k,1]) + f0_hz_interp = F.interpolate(f0_hz,t.shape[-1],mode='linear',align_corners=False) + k_f0 = k*f0_hz_interp + k_f0_sum = 2*np.pi*torch.cumsum(k_f0,-1)/(1.0*self.wave_fr) + wave_k = np.sqrt(2)*torch.sin(k_f0_sum) * (-torch.sign(k_f0-7800)*0.5+0.5) + # wave = 0.12*torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-6000)*0.5+0.5) + # wave = 0.09*torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-self.wave_fr/2)*0.5+0.5) + # wave = 0.09*torch.sigmoid(self.wave_hamon_amplifier) * torch.sin(2*np.pi*k_f0*t) * (-torch.sign(k_f0-self.wave_fr/2)*0.5+0.5) + # wave = wave_k.sum(dim=1,keepdim=True) + # wave = F.softplus(self.wave_hamon_amplifier) * wave.sum(dim=1,keepdim=True) + # spec = wave2spec(wave,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=self.dbbased,power=2. if self.power_synth else 1.) + return wave_k #B,k,T + # if self.return_wave: + # return spec,wave_k + # else: + # return spec + + def unvoicing_wavebased(self,f0_hz,bg=False,mapping=True): + # return torch.ones([1,1,f0_hz.shape[2],512]) + # noise = 0.3*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.03*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + if bg: + noise = torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + if mapping: + noise = self.bgnoise_mapping2(noise) + else: + noise = np.sqrt(3.)*(2*torch.rand([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)])-1) + if mapping: + noise = self.noise_mapping2(noise) + # noise = 0.3 * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.3 * F.softplus(self.wave_noise_amplifier) * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + return wave2spec(noise,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=False,power=2. if self.power_synth else 1.) + # return torchaudio.transforms.Spectrogram(self.n_fft*2-1,win_length=self.n_fft*2-1,hop_length=int(self.wave_fr/self.spec_fr),power=2. 
if self.power_synth else 1.)(noise) + + # def unvoicing_wavebased(self,f0_hz): + # # return torch.ones([1,1,f0_hz.shape[2],512]) + # # noise = 0.3*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # noise = 0.1*torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # # noise = 0.3 * torch.sigmoid(self.wave_noise_amplifier) * torch.randn([1,1,int(f0_hz.shape[2]/self.spec_fr*self.wave_fr)]) + # return wave2spec(noise,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=False) + + def voicing_linear(self,f0_hz,bandwith=2.5): + #f0: B*1*time, hz + freq_cord = torch.arange(self.n_fft) + time_cord = torch.arange(f0_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + f0_hz = f0_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, 1 + f0_hz = f0_hz.repeat([1,1,1,self.k]) #B,time,1, self.k + f0_hz = f0_hz*(torch.arange(self.k)+1).reshape([1,1,1,self.k]) + # bandwith=4 + # bandwith_lower = torch.clamp(f0-bandwith/2,min=1) + # bandwith_upper = f0+bandwith/2 + # bandwith = mel_scale(self.n_mels,bandwith_upper) - mel_scale(self.n_mels,bandwith_lower) + f0 = hz2ind(f0_hz,self.n_fft) + + # hamonics = torch.exp(-(grid_freq-f0)**2/(2*sigma**2)) #gaussian + # hamonics = (1-((grid_freq_hz-f0_hz)/(2*bandwith_hz/2))**2)*(-torch.sign(torch.abs(grid_freq_hz-f0_hz)/(2*bandwith_hz)-0.5)*0.5+0.5) #welch + freq_cord_reshape = freq_cord.reshape([1,1,1,self.n_fft]) + hamonics = (1 - 2/bandwith*(grid_freq-f0).abs())*(-torch.sign(torch.abs(grid_freq-f0)/(bandwith)-0.5)*0.5+0.5) #triangular + # hamonics = (1-((grid_freq-f0)/(bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(2.5*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(2.5*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(3*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(3*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = torch.cos(np.pi*torch.abs(grid_freq-f0)/(4*bandwith))**2*(-torch.sign(torch.abs(grid_freq-f0)/(4*bandwith)-0.5)*0.5+0.5) #hanning + # hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) # B,1,T,F + # condition = (torch.sign(freq_cord_reshape-switch)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-slop*(freq_cord_reshape-switch)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)-torch.abs(self.prior_exp_parameter[2])) * torch.exp(-torch.abs(self.prior_exp_parameter[1])*freq_cord.reshape([1,1,1,self.n_mels])) + torch.abs(self.prior_exp_parameter[2]) # B,1,T,F + + # timbre_parameter = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]).unsqueeze(1) + # condition = (torch.sign(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) * F.softplus(timbre_parameter[...,2:3]) + timbre_parameter[...,3:4] # B,1,T,F + # timbre = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]) + # 
hamonics = (hamonics.sum(dim=-1)*timbre).unsqueeze(dim=1) # B,1,T,F + # hamonics = (hamonics.sum(dim=-1)*self.timbre).unsqueeze(dim=1) # B,1,T,F + + hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) + # hamonics = 180*F.softplus(self.wave_hamon_amplifier)*(hamonics.sum(dim=-1)).unsqueeze(dim=1) + + return hamonics + + def voicing(self,f0_hz): + #f0: B*1*time, hz + freq_cord = torch.arange(self.n_mels) + time_cord = torch.arange(f0_hz.shape[2]) + grid_time,grid_freq = torch.meshgrid(time_cord,freq_cord) + grid_time = grid_time.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq = grid_freq.unsqueeze(dim=0).unsqueeze(dim=-1) #B,time,freq, 1 + grid_freq_hz = inverse_mel_scale(grid_freq/(self.n_mels*1.0)) + f0_hz = f0_hz.permute([0,2,1]).unsqueeze(dim=-2) #B,time,1, 1 + f0_hz = f0_hz.repeat([1,1,1,self.k]) #B,time,1, self.k + f0_hz = f0_hz*(torch.arange(self.k)+1).reshape([1,1,1,self.k]) + if self.log10: + f0_mel = mel_scale(self.n_mels,f0_hz) + band_low_hz = inverse_mel_scale((f0_mel-1)/(self.n_mels*1.0),n_mels = self.n_mels) + band_up_hz = inverse_mel_scale((f0_mel+1)/(self.n_mels*1.0),n_mels = self.n_mels) + bandwith_hz = band_up_hz-band_low_hz + band_low_mel = mel_scale(self.n_mels,band_low_hz) + band_up_mel = mel_scale(self.n_mels,band_up_hz) + bandwith = band_up_mel-band_low_mel + else: + bandwith_hz = 24.7*(f0_hz*4.37/1000+1) + bandwith = bandwidth_mel(f0_hz,bandwith_hz,self.n_mels) + # bandwith_lower = torch.clamp(f0-bandwith/2,min=1) + # bandwith_upper = f0+bandwith/2 + # bandwith = mel_scale(self.n_mels,bandwith_upper) - mel_scale(self.n_mels,bandwith_lower) + f0 = mel_scale(self.n_mels,f0_hz) + switch = mel_scale(self.n_mels,torch.abs(self.timbre_parameter[0])*f0_hz[...,0]).unsqueeze(1) + slop = (torch.abs(self.timbre_parameter[1])*f0_hz[...,0]).unsqueeze(1) + freq_cord_reshape = freq_cord.reshape([1,1,1,self.n_mels]) + if not self.dbbased: + # sigma = bandwith/(np.sqrt(2*np.log(2))); + sigma = bandwith/(2*np.sqrt(2*np.log(2))); + hamonics = torch.exp(-(grid_freq-f0)**2/(2*sigma**2)) #gaussian + # hamonics = (1-((grid_freq_hz-f0_hz)/(2*bandwith_hz/2))**2)*(-torch.sign(torch.abs(grid_freq_hz-f0_hz)/(2*bandwith_hz)-0.5)*0.5+0.5) #welch + else: + # # hamonics = (1-((grid_freq-f0)/(1.75*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(1.75*bandwith)-0.5)*0.5+0.5) #welch + hamonics = (1-((grid_freq-f0)/(2.5*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(2.5*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = (1-((grid_freq-f0)/(3*bandwith/2))**2)*(-torch.sign(torch.abs(grid_freq-f0)/(3*bandwith)-0.5)*0.5+0.5) #welch + # hamonics = torch.cos(np.pi*torch.abs(grid_freq-f0)/(4*bandwith))**2*(-torch.sign(torch.abs(grid_freq-f0)/(4*bandwith)-0.5)*0.5+0.5) #hanning + # hamonics = (hamonics.sum(dim=-1)).unsqueeze(dim=1) # B,1,T,F + # condition = (torch.sign(freq_cord_reshape-switch)*0.5+0.5) + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-slop*(freq_cord_reshape-switch)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)-torch.abs(self.prior_exp_parameter[2])) * torch.exp(-torch.abs(self.prior_exp_parameter[1])*freq_cord.reshape([1,1,1,self.n_mels])) + torch.abs(self.prior_exp_parameter[2]) # B,1,T,F + + timbre_parameter = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]).unsqueeze(1) + condition = (torch.sign(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*0.5+0.5) + amp = F.softplus(self.wave_hamon_amplifier) if self.dbbased else 
180*F.softplus(self.wave_hamon_amplifier) + hamonics = amp * ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) # B,1,T,F + # hamonics = ((hamonics.sum(dim=-1)).unsqueeze(dim=1)) * (1+ (torch.exp(-0.01*torch.sigmoid(timbre_parameter[...,1:2])*(freq_cord_reshape-torch.sigmoid(timbre_parameter[...,0:1])*self.n_mels)*condition)-1)*condition) * F.softplus(timbre_parameter[...,2:3]) + timbre_parameter[...,3:4] # B,1,T,F + # timbre = self.timbre_mapping(f0_hz[...,0,0].unsqueeze(1)).permute([0,2,1]) + # hamonics = (hamonics.sum(dim=-1)*timbre).unsqueeze(dim=1) # B,1,T,F + # hamonics = (hamonics.sum(dim=-1)*self.timbre).unsqueeze(dim=1) # B,1,T,F + # return F.softplus(self.wave_hamon_amplifier)*hamonics + return hamonics + + def unvoicing(self,f0,bg=False,mapping=True): + # return (0.25*torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels]))+1 + rnd = torch.randn([f0.shape[0],2,f0.shape[2],self.n_fft if self.wavebased else self.n_mels]) + if mapping: + rnd = self.bgnoise_mapping(rnd) if bg else self.noise_mapping(rnd) + real = rnd[:,0:1] + img = rnd[:,1:2] + if self.dbbased: + return (2*torchaudio.transforms.AmplitudeToDB()(torch.sqrt(real**2 + img**2+1E-10))+80).clamp(min=0)/35 + # return (2*torchaudio.transforms.AmplitudeToDB()(F.softplus(self.wave_noise_amplifier)*torch.sqrt(torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2 + torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2))+80).clamp(min=0)/35 + else: + # return torch.ones([f0.shape[0],1,f0.shape[2],self.n_mels]) + return 180*F.softplus(self.wave_noise_amplifier) * torch.sqrt(real**2 + img**2+1E-10) + # return F.softplus(self.wave_noise_amplifier)*torch.sqrt(torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2 + torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels])**2) + + # return (F.softplus(self.wave_noise_amplifier)) * (0.25*torch.randn([f0.shape[0],1,f0.shape[2],self.n_mels]))+1 + # return torch.ones([f0.shape[0],1,f0.shape[2],self.n_mels]) + + def forward(self,components,enable_hamon_excitation=True,enable_noise_excitation=True,enable_bgnoise=True): + # f0: B*1*T, amplitudes: B*2(voicing,unvoicing)*T, freq_formants,bandwidth_formants,amplitude_formants: B*formants*T + amplitudes = components['amplitudes'].unsqueeze(dim=-1) + amplitudes_h = components['amplitudes_h'].unsqueeze(dim=-1) + loudness = components['loudness'].unsqueeze(dim=-1) + f0_hz = components['f0_hz'] + # import pdb;pdb.set_trace() + if self.wavebased: + # self.hamonics = 1800*F.softplus(self.wave_hamon_amplifier)*self.voicing_linear(f0_hz) + # self.noise = 180*self.unvoicing(f0_hz,bg=False,mapping=False) + # self.bgnoise = 18*self.unvoicing(f0_hz,bg=True,mapping=False) + # import pdb;pdb.set_trace() + self.hamonics_wave = self.voicing_wavebased(f0_hz) + self.noise = self.unvoicing_wavebased(f0_hz,bg=False,mapping=False) + self.bgnoise = self.unvoicing_wavebased(f0_hz,bg=True) + else: + self.hamonics = self.voicing(f0_hz) + self.noise = self.unvoicing(f0_hz,bg=False) + self.bgnoise = self.unvoicing(f0_hz,bg=True) + # freq_formants = components['freq_formants']*self.n_mels + # bandwidth_formants = components['bandwidth_formants']*self.n_mels + # excitation = amplitudes[:,0:1]*hamonics + # excitation = loudness*(amplitudes[:,0:1]*hamonics) + + self.excitation_noise = loudness*(amplitudes[:,-1:])*self.noise if self.power_synth else (loudness*amplitudes[:,-1:]+1E-10).sqrt()*self.noise + duomask = 
components['freq_formants_noise_hz'].shape[1]>components['freq_formants_hamon_hz'].shape[1] + n_formant_noise = (components['freq_formants_noise_hz'].shape[1]-components['freq_formants_hamon_hz'].shape[1]) if duomask else components['freq_formants_noise_hz'].shape[1] + self.hamonic_dist = self.formant_mask(components['freq_formants_hamon_hz'],components['bandwidth_formants_hamon_hz'],components['amplitude_formants_hamon'],linear = self.linear_scale,f0_hz = f0_hz) + self.mask_noise = self.formant_mask(components['freq_formants_noise_hz'],components['bandwidth_formants_noise_hz'],components['amplitude_formants_noise'],linear = self.linear_scale,triangle_mask=False if self.wavebased else True,duomask=duomask,n_formant_noise=n_formant_noise,f0_hz = f0_hz,noise=True) + # self.mask_hamon = self.formant_mask(components['freq_formants_hamon']*self.n_mels,components['bandwidth_formants_hamon'],components['amplitude_formants_hamon']) + # self.mask_noise = self.formant_mask(components['freq_formants_noise']*self.n_mels,components['bandwidth_formants_noise'],components['amplitude_formants_noise']) + if self.power_synth: + self.excitation_hamon_wave = F.interpolate(loudness[...,-1]*amplitudes[:,0:1][...,-1],self.hamonics_wave.shape[-1],mode='linear',align_corners=False)*self.hamonics_wave + else: + self.excitation_hamon_wave = F.interpolate((loudness[...,-1]*amplitudes[:,0:1][...,-1]+1E-10).sqrt(),self.hamonics_wave.shape[-1],mode='linear',align_corners=False)*self.hamonics_wave + self.hamonics_wave_ = (self.excitation_hamon_wave*self.hamonic_dist).sum(1,keepdim=True) + + bgdist = F.softplus(self.bgnoise_amp)*self.noise_dist if self.noise_from_data else F.softplus(self.bgnoise_dist) + # if self.power_synth: + # self.excitation_hamon = loudness*(amplitudes[:,0:1])*self.hamonics + # else: + # self.excitation_hamon = loudness*amplitudes[:,0:1]*self.hamonics + # import pdb;pdb.set_trace() + self.noise_excitation = self.excitation_noise*self.mask_noise + + self.noise_excitation_wave = 2*inverse_spec_to_audio(self.noise_excitation.squeeze(1).permute(0,2,1),n_fft=self.n_fft*2-1,power_synth=self.power_synth) + self.noise_excitation_wave = F.pad(self.noise_excitation_wave,[0,self.hamonics_wave_.shape[2]-self.noise_excitation_wave.shape[1]]) + self.noise_excitation_wave = self.noise_excitation_wave.unsqueeze(1) + self.rec_wave_clean = self.noise_excitation_wave+self.hamonics_wave_ + + if self.add_bgnoise and enable_bgnoise: + self.bgn = bgdist*self.bgnoise*0.0003 + self.bgn_wave = 2*inverse_spec_to_audio(self.bgn.squeeze(1).permute(0,2,1),n_fft=self.n_fft*2-1,power_synth=self.power_synth) + self.bgn_wave = F.pad(self.bgn_wave,[0,self.hamonics_wave_.shape[2]-self.bgn_wave.shape[1]]) + self.bgn_wave = self.bgn_wave.unsqueeze(1) + self.rec_wave = self.rec_wave_clean + self.bgn_wave + else: + self.rec_wave = self.rec_wave_clean + + speech = wave2spec(self.rec_wave,self.n_fft,self.wave_fr,self.spec_fr,self.noise_db,self.max_db,to_db=True,power=2 if self.power_synth else 1) + # if self.wavebased: + # # import pdb; pdb.set_trace() + # bgn = bgdist*self.bgnoise*0.0003 if (self.add_bgnoise and enable_bgnoise) else 0 + # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.noise_excitation if enable_noise_excitation else 0) + bgn + # # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.excitation_noise*self.mask_noise_sum if enable_noise_excitation else 0) 
+ (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. + self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + # # speech = speech if self.power_synth else speech**2 + # speech = (torchaudio.transforms.AmplitudeToDB()(speech).clamp(min=self.noise_db)-self.noise_db)/(self.max_db-self.noise_db)*2-1 + # else: + # # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. + self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + (self.silient*torch.ones(self.mask_hamon_sum.shape) if self.dbbased else 0) + # speech = ((self.excitation_hamon*self.mask_hamon_sum) if enable_hamon_excitation else torch.zeros(self.excitation_hamon.shape)) + (self.noise_excitation if enable_noise_excitation else 0) + (((bgdist*self.bgnoise*0.0003) if not self.dbbased else (2*torchaudio.transforms.AmplitudeToDB()(bgdist*0.0003)/35. + self.bgnoise)) if (self.add_bgnoise and enable_bgnoise) else 0) + (self.silient*torch.ones(self.mask_hamon_sum.shape) if self.dbbased else 0) + # # speech = self.excitation_hamon*self.mask_hamon_sum + (self.excitation_noise*self.mask_noise_sum if enable_noise_excitation else 0) + self.silient*torch.ones(self.mask_hamon_sum.shape) + # if not self.dbbased: + # speech = db(speech) + + + # import pdb;pdb.set_trace() + if self.return_wave: + return speech,self.rec_wave_clean + else: + return speech + +@ENCODERS.register("EncoderFormant") +class FormantEncoder(nn.Module): + def __init__(self, n_mels=64, n_formants=4,n_formants_noise=2,min_octave=-31,max_octave=96,wavebased=False,hop_length=128,n_fft=256,noise_db=-50,max_db=22.5,broud=True,power_synth=False): + super(FormantEncoder, self).__init__() + self.wavebased = wavebased + self.n_mels = n_mels + self.n_formants = n_formants + self.n_formants_noise = n_formants_noise + self.min_octave = min_octave + self.max_octave = max_octave + self.noise_db = noise_db + self.max_db = max_db + self.broud = broud + self.hop_length = hop_length + self.n_fft = n_fft + self.power_synth=power_synth + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,700.,2700.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + # 
self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,500.,500.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_bias = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_thres = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_bias,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + nn.init.constant_(self.formant_bandwitdh_thres,0) + + # self.formant_freq_limits = torch.cumsum(self.formant_freq_limits_diff,dim=0) + # self.formant_freq_limits_mel = torch.cat([torch.tensor([0.]),mel_scale(n_mels,self.formant_freq_limits)/n_mels]) + # self.formant_freq_limits_mel_diff = torch.reshape(self.formant_freq_limits_mel[1:]-self.formant_freq_limits_mel[:-1],[1,3,1]) + if broud: + if wavebased: + self.conv1_narrow = ln.Conv1d(n_fft,64,3,1,1) + self.conv1_mel = ln.Conv1d(128,64,3,1,1) + self.norm1_mel = nn.GroupNorm(32,64) + self.conv2_mel = ln.Conv1d(64,128,3,1,1) + self.norm2_mel = nn.GroupNorm(32,128) + self.conv_fundementals_mel = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals_mel = nn.GroupNorm(32,128) + self.f0_drop_mel = nn.Dropout() + else: + self.conv1_narrow = ln.Conv1d(n_mels,64,3,1,1) + self.norm1_narrow = nn.GroupNorm(32,64) + self.conv2_narrow = ln.Conv1d(64,128,3,1,1) + self.norm2_narrow = nn.GroupNorm(32,128) + + self.conv_fundementals_narrow = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals_narrow = nn.GroupNorm(32,128) + self.f0_drop_narrow = nn.Dropout() + if wavebased: + self.conv_f0_narrow = ln.Conv1d(256,1,1,1,0) + else: + self.conv_f0_narrow = ln.Conv1d(128,1,1,1,0) + + self.conv_amplitudes_narrow = ln.Conv1d(128,2,1,1,0) + self.conv_amplitudes_h_narrow = ln.Conv1d(128,2,1,1,0) + + if wavebased: + self.conv1 = ln.Conv1d(n_fft,64,3,1,1) + else: + self.conv1 = ln.Conv1d(n_mels,64,3,1,1) + self.norm1 = nn.GroupNorm(32,64) + self.conv2 = ln.Conv1d(64,128,3,1,1) + self.norm2 = nn.GroupNorm(32,128) + + self.conv_fundementals = ln.Conv1d(128,128,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,128) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(128,1,1,1,0) + + self.conv_amplitudes = ln.Conv1d(128,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(128,2,1,1,0) + # self.conv_loudness = nn.Sequential(ln.Conv1d(n_fft if wavebased else n_mels,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,1,1,1,0,bias_initial=0.5),) + self.conv_loudness = nn.Sequential(ln.Conv1d(n_fft if wavebased else n_mels,128,1,1,0), + nn.LeakyReLU(0.2), + ln.Conv1d(128,128,1,1,0), + nn.LeakyReLU(0.2), + ln.Conv1d(128,1,1,1,0,bias_initial=-9.),) + # self.conv_loudness_power = nn.Sequential(ln.Conv1d(1,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,128,1,1,0), + # nn.LeakyReLU(0.2), + # ln.Conv1d(128,1,1,1,0,bias_initial=-9.),) + + if self.broud: + self.conv_formants = ln.Conv1d(128,128,3,1,1) + else: + self.conv_formants = ln.Conv1d(128,128,3,1,1) + self.norm_formants = nn.GroupNorm(32,128) + self.conv_formants_freqs = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(128,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(128,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(128,self.n_formants_noise,1,1,0) + + self.amplifier = Parameter(torch.Tensor(1)) + self.bias = Parameter(torch.Tensor(1)) + with 
torch.no_grad(): + nn.init.constant_(self.amplifier,1.0) + nn.init.constant_(self.bias,0.) + + def forward(self,x,x_denoise=None,duomask=False,noise_level = None,x_amp=None): + x = x.squeeze(dim=1).permute(0,2,1) #B * f * T + if x_denoise is not None: + x_denoise = x_denoise.squeeze(dim=1).permute(0,2,1) + # x_denoise_amp = amplitude(x_denoise,self.noise_db,self.max_db) + # import pdb; pdb.set_trace() + if x_amp is None: + x_amp = amplitude(x,self.noise_db,self.max_db,trim_noise=True) + else: + x_amp = x_amp.squeeze(dim=1).permute(0,2,1) + hann_win = torch.hann_window(5,periodic=False).reshape([1,1,5,1]) + x_smooth = F.conv2d(x.unsqueeze(1).transpose(-2,-1),hann_win,padding=[2,0]).transpose(-2,-1).squeeze(1) + # loudness = F.softplus(self.amplifier)*(torch.mean(x_denoise_amp,dim=1,keepdim=True)) + # loudness = F.relu(F.softplus(self.amplifier)*(torch.mean(x_amp,dim=1,keepdim=True)-noise_level*0.0003)) + # loudness = torch.mean((x*0.5+0.5) if x_denoise is None else (x_denoise*0.5+0.5),dim=1,keepdim=True) + # loudness = F.softplus(self.amplifier)*(loudness) + # loudness = F.softplus(self.amplifier)*torch.mean(x_amp,dim=1,keepdim=True) + # loudness = F.softplus(self.amplifier)*F.relu(loudness - F.softplus(self.bias)) + if self.power_synth: + loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x_smooth)) + else: + # loudness = F.softplus(self.amplifier)*F.relu((x_amp**2).sum(1,keepdim=True)/self.hop_length/383.-self.bias) + loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x_smooth)) # compute power loudness + # loudness = F.softplus(self.amplifier)*F.relu((x_amp**2).sum(1,keepdim=True)-self.bias) # compute power loudness + # loudness = F.softplus((1. if self.wavebased else 1.0)*self.conv_loudness(x)) + # loudness = F.relu(self.conv_loudness(x)) + + # if not self.power_synth: + # loudness = loudness.sqrt() + + if self.broud: + x_narrow = x + x_narrow = F.leaky_relu(self.norm1_narrow(self.conv1_narrow(x_narrow)),0.2) + x_common_narrow = F.leaky_relu(self.norm2_narrow(self.conv2_narrow(x_narrow)),0.2) + amplitudes = F.softmax(self.conv_amplitudes_narrow(x_common_narrow),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h_narrow(x_common_narrow),dim=1) + x_fundementals_narrow = self.f0_drop_narrow(F.leaky_relu(self.norm_fundementals_narrow(self.conv_fundementals_narrow(x_common_narrow)),0.2)) + + x_amp = amplitude(x.unsqueeze(1),self.noise_db,self.max_db).transpose(-2,-1) + x_mel = to_db(torchaudio.transforms.MelScale(f_max=8000,n_stft=self.n_fft)(x_amp.transpose(-2,-1)),self.noise_db,self.max_db).squeeze(1) + x = F.leaky_relu(self.norm1_mel(self.conv1_mel(x_mel)),0.2) + x_common_mel = F.leaky_relu(self.norm2_mel(self.conv2_mel(x)),0.2) + x_fundementals_mel = self.f0_drop_mel(F.leaky_relu(self.norm_fundementals_mel(self.conv_fundementals_mel(x_common_mel)),0.2)) + + f0_hz = torch.sigmoid(self.conv_f0_narrow(torch.cat([x_fundementals_narrow,x_fundementals_mel],dim=1))) * 120 + 180 # 180hz < f0 < 300 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + hann_win = torch.hann_window(21,periodic=False).reshape([1,1,21,1]) + x = to_db(F.conv2d(x_amp,hann_win,padding=[10,0]).transpose(-2,-1),self.noise_db,self.max_db).squeeze(1) + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x_common = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + + + + else: + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x_common = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + + + # loudness = F.relu(self.conv_loudness(x_common)) + # loudness = 
F.relu(self.conv_loudness(x_common)) +(10**(self.noise_db/10.-1) if self.wavebased else 0) + amplitudes = F.softmax(self.conv_amplitudes(x_common),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + + # x_fundementals = F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) + # f0 in mel: + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) + # f0 = F.tanh(self.conv_f0(x_fundementals)) * (16/64)*(self.n_mels/64) # 72hz < f0 < 446 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + + # f0 in hz: + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 120 + 180 # 180hz < f0 < 300 hz + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 120 + 180 # 180hz < f0 < 300 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 528 + 88 # 88hz < f0 < 616 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 302 + 118 # 118hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 240 + 180 # 180hz < f0 < 420 hz + # f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 260 + 160 # 160hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + # # relative freq: + # formants_freqs_hz = formants_freqs*(self.formant_freq_limits_diff[:,:self.n_formants]-self.formant_freq_limits_diff_low[:,:self.n_formants])+self.formant_freq_limits_diff_low[:,:self.n_formants] + # # formants_freqs_hz = formants_freqs*6839 + # formants_freqs_hz = torch.cumsum(formants_freqs_hz,dim=1) + + # abs freq: + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + # formants_freqs_hz = formants_freqs*6839 + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) *6839 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 150 + # formants_bandwidth_hz = 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+100) + # formants_bandwidth_hz = (torch.sigmoid(self.conv_formants_bandwidth(x_formants))) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100) #good for spec based method + # formants_bandwidth_hz = ((3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.0125*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + # formants_bandwidth_hz = (2**(torch.tanh(self.formant_bandwitdh_slop))*0.001*torch.relu(formants_freqs_hz-4000*torch.sigmoid(self.formant_bandwitdh_thres))+375*2**(torch.tanh(self.formant_bandwitdh_bias))) + # formants_bandwidth_hz = 
torch.exp(0.4*torch.tanh(self.conv_formants_bandwidth(x_formants))) * (0.00625*torch.relu(formants_freqs_hz-0)+375) + # formants_bandwidth_hz = (100*(torch.tanh(self.conv_formants_bandwidth(x_formants))) + (0.035*torch.relu(formants_freqs_hz-950)+250)) if self.wavebased else ((3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.0125*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + # formants_bandwidth_hz = (100*(torch.tanh(self.conv_formants_bandwidth(x_formants))) + (0.035*torch.relu(formants_freqs_hz-950)+250)) if self.wavebased else ((torch.sigmoid(self.conv_formants_bandwidth(x_formants))+0.2) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100))#good for spec based method + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * (0.075*3*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+(2*torch.sigmoid(self.formant_bandwitdh_ratio)+1)*50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+50) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + # formants_bandwidth_upper = formants_freqs_hz+formants_bandwidth_hz/2 + # formants_bandwidth_lower = torch.clamp(formants_freqs_hz-formants_bandwidth_hz/2,min=1) + # formants_bandwidth = (mel_scale(self.n_mels,formants_bandwidth_upper) - mel_scale(self.n_mels,formants_bandwidth_lower))/(self.n_mels*1.0) + # formants_amplitude = F.softmax(torch.cumsum(-F.relu(self.conv_formants_amplitude(x_formants)),dim=1),dim=1) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + # # relative freq: + # formants_freqs_hz = formants_freqs*(self.formant_freq_limits_diff[:,:self.n_formants]-self.formant_freq_limits_diff_low[:,:self.n_formants])+self.formant_freq_limits_diff_low[:,:self.n_formants] + # # formants_freqs_hz = formants_freqs*6839 + # formants_freqs_hz = torch.cumsum(formants_freqs_hz,dim=1) + + # abs freq: + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + if duomask: + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + # formants_freqs_hz = formants_freqs*6839 + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) *6839 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 150 + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 8000 + 2000 + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + # formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 8000 + 2000 #2000-10000 + # formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 
2000 #0-2000 + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + if duomask: + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 4000 + 1000 + # formants_bandwidth_hz_noise = torch.sigmoid(self.conv_formants_bandwidth_noise(x_formants)) * 4000 + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.relu(formants_freqs_hz-1000)+50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * (0.075*3*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+(2*torch.sigmoid(self.formant_bandwitdh_ratio)+1)*50) + # formants_bandwidth_hz = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) * 3*torch.sigmoid(self.formant_bandwitdh_ratio)*(0.075*torch.sigmoid(self.formant_bandwitdh_slop)*torch.relu(formants_freqs_hz-1000)+50) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + # formants_bandwidth_upper = formants_freqs_hz+formants_bandwidth_hz/2 + # formants_bandwidth_lower = torch.clamp(formants_freqs_hz-formants_bandwidth_hz/2,min=1) + # formants_bandwidth = (mel_scale(self.n_mels,formants_bandwidth_upper) - mel_scale(self.n_mels,formants_bandwidth_lower))/(self.n_mels*1.0) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + if duomask: + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + +class FromECoG(nn.Module): + def __init__(self, outputs,residual=False,shape='3D'): + super().__init__() + self.residual=residual + if shape =='3D': + self.from_ecog = ln.Conv3d(1, outputs, [9,1,1], 1, [4,0,0]) + else: + self.from_ecog = ln.Conv2d(1, outputs, [9,1], 1, [4,0]) + + def forward(self, x): + x = self.from_ecog(x) + if not self.residual: + x = F.leaky_relu(x, 0.2) + return x + +class ECoGMappingBlock(nn.Module): + def __init__(self, inputs, outputs, kernel_size,dilation=1,fused_scale=True,residual=False,resample=[],pool=None,shape='3D'): + super(ECoGMappingBlock, self).__init__() + self.residual = residual + self.pool = pool + self.inputs_resample = resample + self.dim_missmatch = (inputs!=outputs) + self.resample = resample + if not self.resample: + self.resample=1 + self.padding = list(np.array(dilation)*(np.array(kernel_size)-1)//2) + if shape=='2D': + conv=ln.Conv2d + maxpool = 
nn.MaxPool2d + avgpool = nn.AvgPool2d + if shape=='3D': + conv=ln.Conv3d + maxpool = nn.MaxPool3d + avgpool = nn.AvgPool3d + # self.padding = [dilation[i]*(kernel_size[i]-1)//2 for i in range(len(dilation))] + if residual: + self.norm1 = nn.GroupNorm(min(inputs,32),inputs) + else: + self.norm1 = nn.GroupNorm(min(outputs,32),outputs) + if pool is None: + self.conv1 = conv(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False) + else: + self.conv1 = conv(inputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.pool1 = maxpool(self.resample,self.resample) if self.pool=='Max' else avgpool(self.resample,self.resample) + if self.inputs_resample or self.dim_missmatch: + if pool is None: + self.convskip = conv(inputs, outputs, kernel_size, self.resample, self.padding, dilation=dilation, bias=False) + else: + self.convskip = conv(inputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.poolskip = maxpool(self.resample,self.resample) if self.pool=='Max' else avgpool(self.resample,self.resample) + + self.conv2 = conv(outputs, outputs, kernel_size, 1, self.padding, dilation=dilation, bias=False) + self.norm2 = nn.GroupNorm(min(outputs,32),outputs) + + def forward(self,x): + if self.residual: + x = F.leaky_relu(self.norm1(x),0.2) + if self.inputs_resample or self.dim_missmatch: + # x_skip = F.avg_pool3d(x,self.resample,self.resample) + x_skip = self.convskip(x) + if self.pool is not None: + x_skip = self.poolskip(x_skip) + else: + x_skip = x + x = F.leaky_relu(self.norm2(self.conv1(x)),0.2) + if self.pool is not None: + x = self.poolskip(x) + x = self.conv2(x) + x = x_skip + x + else: + x = F.leaky_relu(self.norm1(self.conv1(x)),0.2) + x = F.leaky_relu(self.norm2(self.conv2(x)),0.2) + return x + + + +@ECOG_ENCODER.register("ECoGMappingBottleneck") +class ECoGMapping_Bottleneck(nn.Module): + def __init__(self,n_mels,n_formants,n_formants_noise=1): + super(ECoGMapping_Bottleneck, self).__init__() + self.n_formants = n_formants + self.n_mels = n_mels + self.n_formants_noise = n_formants_noise + + self.formant_freq_limits_diff = torch.tensor([950.,2450.,2100.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_diff_low = torch.tensor([300.,300.,0.]).reshape([1,3,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,2800.,3400.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2800.,3400]).reshape([1,4,1]) #freq difference + + # self.formant_freq_limits_abs = torch.tensor([950.,3300.,3600.,4700.]).reshape([1,4,1]) #freq difference + # self.formant_freq_limits_abs_low = torch.tensor([300.,600.,2700.,3400]).reshape([1,4,1]) #freq difference + + self.formant_freq_limits_abs = torch.tensor([950.,3400.,3800.,5000.,6000.,7000.]).reshape([1,6,1]) #freq difference + self.formant_freq_limits_abs_low = torch.tensor([300.,700.,1800.,3400,5000.,6000.]).reshape([1,6,1]) #freq difference + + # self.formant_freq_limits_abs_noise = torch.tensor([7000.]).reshape([1,1,1]) #freq difference + self.formant_freq_limits_abs_noise = torch.tensor([8000.,7000.,7000.]).reshape([1,3,1]) #freq difference + self.formant_freq_limits_abs_noise_low = torch.tensor([4000.,3000.,3000.]).reshape([1,3,1]) #freq difference + + self.formant_bandwitdh_ratio = Parameter(torch.Tensor(1)) + self.formant_bandwitdh_slop = Parameter(torch.Tensor(1)) + with torch.no_grad(): + nn.init.constant_(self.formant_bandwitdh_ratio,0) + nn.init.constant_(self.formant_bandwitdh_slop,0) + 
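+        # Rough shape bookkeeping for the layers defined below (a sketch, assuming the
+        # 15x15 electrode grid used in forward()): ECoG is reshaped to (B, 1, T, 15, 15);
+        # from_ecog/conv1/conv2 are 3D blocks that halve the time axis twice, a sigmoid
+        # mask (optionally multiplied by mask_prior) weights electrodes, conv3/conv4
+        # downsample time and the grid further, and a spatial max over the grid collapses
+        # the electrode axes. The four stride-2 ConvTranspose1d layers conv6-conv9 then
+        # undo the overall 16x temporal downsampling, so the heads (f0, loudness,
+        # amplitudes, formant frequency/bandwidth/amplitude) emit per-frame trajectories
+        # at the input ECoG frame rate.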
+ + self.from_ecog = FromECoG(16,residual=True) + self.conv1 = ECoGMappingBlock(16,32,[5,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.conv2 = ECoGMappingBlock(32,64,[3,1,1],residual=True,resample = [2,1,1],pool='MAX') + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv3d(64,1,[3,1,1],1,[1,0,0]) + self.conv3 = ECoGMappingBlock(64,128,[3,3,3],residual=True,resample = [2,2,2],pool='MAX') + self.conv4 = ECoGMappingBlock(128,256,[3,3,3],residual=True,resample = [2,2,2],pool='MAX') + self.norm = nn.GroupNorm(32,256) + self.conv5 = ln.Conv1d(256,256,3,1,1) + self.norm2 = nn.GroupNorm(32,256) + self.conv6 = ln.ConvTranspose1d(256, 128, 3, 2, 1, transform_kernel=True) + self.norm3 = nn.GroupNorm(32,128) + self.conv7 = ln.ConvTranspose1d(128, 64, 3, 2, 1, transform_kernel=True) + self.norm4 = nn.GroupNorm(32,64) + self.conv8 = ln.ConvTranspose1d(64, 32, 3, 2, 1, transform_kernel=True) + self.norm5 = nn.GroupNorm(32,32) + self.conv9 = ln.ConvTranspose1d(32, 32, 3, 2, 1, transform_kernel=True) + self.norm6 = nn.GroupNorm(32,32) + + self.conv_fundementals = ln.Conv1d(32,32,3,1,1) + self.norm_fundementals = nn.GroupNorm(32,32) + self.f0_drop = nn.Dropout() + self.conv_f0 = ln.Conv1d(32,1,1,1,0) + self.conv_amplitudes = ln.Conv1d(32,2,1,1,0) + self.conv_amplitudes_h = ln.Conv1d(32,2,1,1,0) + self.conv_loudness = ln.Conv1d(32,1,1,1,0,bias_initial=-9.) + + self.conv_formants = ln.Conv1d(32,32,3,1,1) + self.norm_formants = nn.GroupNorm(32,32) + self.conv_formants_freqs = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_bandwidth = ln.Conv1d(32,n_formants,1,1,0) + self.conv_formants_amplitude = ln.Conv1d(32,n_formants,1,1,0) + + self.conv_formants_freqs_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_bandwidth_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + self.conv_formants_amplitude_noise = ln.Conv1d(32,n_formants_noise,1,1,0) + + + def forward(self,ecog,mask_prior,mni): + x_common_all = [] + for d in range(len(ecog)): + x = ecog[d] + x = x.reshape([-1,1,x.shape[1],15,15]) + mask_prior_d = mask_prior[d].reshape(-1,1,1,15,15) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,4:] + if mask_prior is not None: + mask = mask*mask_prior_d + x = x[:,:,4:] + x = x*mask + x = self.conv3(x) + x = self.conv4(x) + x = x.max(-1)[0].max(-1)[0] + x = self.conv5(F.leaky_relu(self.norm(x),0.2)) + x = self.conv6(F.leaky_relu(self.norm2(x),0.2)) + x = self.conv7(F.leaky_relu(self.norm3(x),0.2)) + x = self.conv8(F.leaky_relu(self.norm4(x),0.2)) + x = self.conv9(F.leaky_relu(self.norm5(x),0.2)) + x_common = F.leaky_relu(self.norm6(x),0.2) + x_common_all += [x_common] + + x_common = torch.cat(x_common_all,dim=0) + loudness = F.softplus(self.conv_loudness(x_common)) + amplitudes = F.softmax(self.conv_amplitudes(x_common),dim=1) + amplitudes_h = F.softmax(self.conv_amplitudes_h(x_common),dim=1) + + # x_fundementals = F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2) + x_fundementals = self.f0_drop(F.leaky_relu(self.norm_fundementals(self.conv_fundementals(x_common)),0.2)) + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) + # f0 = F.tanh(self.conv_f0(x_fundementals)) * (16/64)*(self.n_mels/64) # 72hz < f0 < 446 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = 
torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + + # f0 in hz: + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * 528 + 88 # 88hz < f0 < 616 hz + f0_hz = torch.sigmoid(self.conv_f0(x_fundementals)) * 332 + 88 # 88hz < f0 < 420 hz + f0 = torch.clamp(mel_scale(self.n_mels,f0_hz)/(self.n_mels*1.0),min=0.0001) + + x_formants = F.leaky_relu(self.norm_formants(self.conv_formants(x_common)),0.2) + formants_freqs = torch.sigmoid(self.conv_formants_freqs(x_formants)) + # formants_freqs = torch.cumsum(formants_freqs,dim=1) + # formants_freqs = formants_freqs + + # abs freq + formants_freqs_hz = formants_freqs*(self.formant_freq_limits_abs[:,:self.n_formants]-self.formant_freq_limits_abs_low[:,:self.n_formants])+self.formant_freq_limits_abs_low[:,:self.n_formants] + formants_freqs = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz)/(self.n_mels*1.0),min=0) + + # formants_freqs = formants_freqs + f0 + # formants_bandwidth = torch.sigmoid(self.conv_formants_bandwidth(x_formants)) + # formants_bandwidth_hz = (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100) + formants_bandwidth_hz = 0.65*(0.00625*torch.relu(formants_freqs_hz)+375) + # formants_bandwidth_hz = (torch.sigmoid(self.conv_formants_bandwidth(x_formants))) * (3*torch.sigmoid(self.formant_bandwitdh_ratio))*(0.075*torch.relu(formants_freqs_hz-1000)+100) + formants_bandwidth = bandwidth_mel(formants_freqs_hz,formants_bandwidth_hz,self.n_mels) + formants_amplitude_logit = self.conv_formants_amplitude(x_formants) + formants_amplitude = F.softmax(formants_amplitude_logit,dim=1) + + formants_freqs_noise = torch.sigmoid(self.conv_formants_freqs_noise(x_formants)) + formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:self.n_formants_noise]-self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise])+self.formant_freq_limits_abs_noise_low[:,:self.n_formants_noise] + # formants_freqs_hz_noise = formants_freqs_noise*(self.formant_freq_limits_abs_noise[:,:1]-self.formant_freq_limits_abs_noise_low[:,:1])+self.formant_freq_limits_abs_noise_low[:,:1] + formants_freqs_hz_noise = torch.cat([formants_freqs_hz,formants_freqs_hz_noise],dim=1) + formants_freqs_noise = torch.clamp(mel_scale(self.n_mels,formants_freqs_hz_noise)/(self.n_mels*1.0),min=0) + # formants_bandwidth_hz_noise = F.relu(self.conv_formants_bandwidth_noise(x_formants)) * 8000 + 2000 + # formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + # formants_amplitude_noise = F.softmax(self.conv_formants_amplitude_noise(x_formants),dim=1) + formants_bandwidth_hz_noise = self.conv_formants_bandwidth_noise(x_formants) + formants_bandwidth_hz_noise_1 = F.softplus(formants_bandwidth_hz_noise[:,:1]) * 2344 + 586 #2000-10000 + formants_bandwidth_hz_noise_2 = torch.sigmoid(formants_bandwidth_hz_noise[:,1:]) * 586 #0-2000 + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz_noise_1,formants_bandwidth_hz_noise_2],dim=1) + formants_bandwidth_hz_noise = torch.cat([formants_bandwidth_hz,formants_bandwidth_hz_noise],dim=1) + formants_bandwidth_noise = bandwidth_mel(formants_freqs_hz_noise,formants_bandwidth_hz_noise,self.n_mels) + formants_amplitude_noise_logit = self.conv_formants_amplitude_noise(x_formants) + formants_amplitude_noise_logit = torch.cat([formants_amplitude_logit,formants_amplitude_noise_logit],dim=1) + formants_amplitude_noise = 
F.softmax(formants_amplitude_noise_logit,dim=1) + + components = { 'f0':f0, + 'f0_hz':f0_hz, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'amplitudes_h':amplitudes_h, + 'freq_formants_hamon':formants_freqs, + 'bandwidth_formants_hamon':formants_bandwidth, + 'freq_formants_hamon_hz':formants_freqs_hz, + 'bandwidth_formants_hamon_hz':formants_bandwidth_hz, + 'amplitude_formants_hamon':formants_amplitude, + 'freq_formants_noise':formants_freqs_noise, + 'bandwidth_formants_noise':formants_bandwidth_noise, + 'freq_formants_noise_hz':formants_freqs_hz_noise, + 'bandwidth_formants_noise_hz':formants_bandwidth_hz_noise, + 'amplitude_formants_noise':formants_amplitude_noise, + } + return components + + +class BackBone(nn.Module): + def __init__(self,attentional_mask=True): + super(BackBone, self).__init__() + self.attentional_mask = attentional_mask + self.from_ecog = FromECoG(16,residual=True,shape='2D') + self.conv1 = ECoGMappingBlock(16,32,[5,1],residual=True,resample = [1,1],shape='2D') + self.conv2 = ECoGMappingBlock(32,64,[3,1],residual=True,resample = [1,1],shape='2D') + self.norm_mask = nn.GroupNorm(32,64) + self.mask = ln.Conv2d(64,1,[3,1],1,[1,0]) + + def forward(self,ecog): + x_common_all = [] + mask_all=[] + for d in range(len(ecog)): + x = ecog[d] + x = x.unsqueeze(1) + x = self.from_ecog(x) + x = self.conv1(x) + x = self.conv2(x) + if self.attentional_mask: + mask = F.relu(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + mask = mask[:,:,16:] + x = x[:,:,16:] + mask_all +=[mask] + else: + # mask = torch.sigmoid(self.mask(F.leaky_relu(self.norm_mask(x),0.2))) + # mask = mask[:,:,16:] + x = x[:,:,16:] + # x = x*mask + + x_common_all +=[x] + + x_common = torch.cat(x_common_all,dim=0) + if self.attentional_mask: + mask = torch.cat(mask_all,dim=0) + return x_common,mask.squeeze(1) if self.attentional_mask else None + +class ECoGEncoderFormantHeads(nn.Module): + def __init__(self,inputs,n_mels,n_formants): + super(ECoGEncoderFormantHeads,self).__init__() + self.n_mels = n_mels + self.f0 = ln.Conv1d(inputs,1,1) + self.loudness = ln.Conv1d(inputs,1,1) + self.amplitudes = ln.Conv1d(inputs,2,1) + self.freq_formants = ln.Conv1d(inputs,n_formants,1) + self.bandwidth_formants = ln.Conv1d(inputs,n_formants,1) + self.amplitude_formants = ln.Conv1d(inputs,n_formants,1) + + def forward(self,x): + loudness = F.relu(self.loudness(x)) + f0 = torch.sigmoid(self.f0(x)) * (15/64)*(self.n_mels/64) # 179hz < f0 < 420 hz + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (22/64)*(self.n_mels/64) - (16/64)*(self.n_mels/64)# 72hz < f0 < 253 hz, human voice + # f0 = torch.sigmoid(self.conv_f0(x_fundementals)) * (11/64)*(self.n_mels/64) - (-2/64)*(self.n_mels/64)# 160hz < f0 < 300 hz, female voice + amplitudes = F.softmax(self.amplitudes(x),dim=1) + freq_formants = torch.sigmoid(self.freq_formants(x)) + freq_formants = torch.cumsum(freq_formants,dim=1) + bandwidth_formants = torch.sigmoid(self.bandwidth_formants(x)) + amplitude_formants = F.softmax(self.amplitude_formants(x),dim=1) + return {'f0':f0, + 'loudness':loudness, + 'amplitudes':amplitudes, + 'freq_formants':freq_formants, + 'bandwidth_formants':bandwidth_formants, + 'amplitude_formants':amplitude_formants,} + +@ECOG_ENCODER.register("ECoGMappingTransformer") +class ECoGMapping_Transformer(nn.Module): + def __init__(self,n_mels,n_formants,SeqLen=128,hidden_dim=256,dim_feedforward=256,encoder_only=False,attentional_mask=False,n_heads=1,non_local=False): + super(ECoGMapping_Transformer, self).__init__() + self.n_mels = n_mels, + self.n_formant = 
n_formants, + self.encoder_only = encoder_only, + self.attentional_mask = attentional_mask, + self.backbone = BackBone(attentional_mask=attentional_mask) + self.position_encoding = build_position_encoding(SeqLen,hidden_dim,'MNI') + self.input_proj = ln.Conv2d(64, hidden_dim, kernel_size=1) + if non_local: + Transformer = TransformerNL + else: + Transformer = TransformerTS + self.transformer = Transformer(d_model=hidden_dim, nhead=n_heads, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=dim_feedforward, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False,encoder_only = encoder_only) + self.output_proj = ECoGEncoderFormantHeads(hidden_dim,n_mels,n_formants) + self.query_embed = nn.Embedding(SeqLen, hidden_dim) + + def forward(self,x,mask_prior,mni): + features,mask = self.backbone(x) + pos = self.position_encoding(mni) + hs = self.transformer(self.input_proj(features), mask if self.attentional_mask else None, self.query_embed.weight, pos) + if not self.encoder_only: + hs,encoded = hs + out = self.output_proj(hs) + else: + _,encoded = hs + encoded = encoded.max(-1)[0] + out = self.output_proj(encoded) + return out + + + diff --git a/registry.py b/registry.py index f6e25be6..cc588777 100644 --- a/registry.py +++ b/registry.py @@ -5,3 +5,4 @@ GENERATORS = Registry() MAPPINGS = Registry() DISCRIMINATORS = Registry() +ECOG_ENCODER = Registry() diff --git a/run.s b/run.s new file mode 100644 index 00000000..93e71065 --- /dev/null +++ b/run.s @@ -0,0 +1,18 @@ +#!/bin/bash +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=2 +#SBATCH --gres=gpu:p40:2 +#SBATCH --time=60:00:00 +#SBATCH --mem=64GB +#SBATCH --job-name=myTest +#SBATCH --output=slurm_%j.out + +cd $SCRATCH/neural_decoding/code/cnn/ALAE/ + +module purge +module load cudnn/10.0v7.6.2.24 +module load cuda/10.0.130 +source $HOME/python3.7/bin/activate +export PYTHONPATH=$PYTHONPATH:$(pwd) +python train_formant.py diff --git a/scheduler.py b/scheduler.py index d92284a7..4bebc582 100644 --- a/scheduler.py +++ b/scheduler.py @@ -63,9 +63,9 @@ def get_lr(self): alpha = float(self.last_epoch) / self.warmup_iters warmup_factor = self.warmup_factor * (1 - alpha) + alpha return [ - base_lr[self.lod] + np.maximum(base_lr[self.lod] * warmup_factor - * self.gamma ** bisect_right(self.milestones, self.last_epoch) + * self.gamma ** bisect_right(self.milestones, self.last_epoch),1e-4) # * float(self.batch_size) # / float(self.reference_batch_size) for base_lr in self.base_lrs diff --git a/tracker.py b/tracker.py index 3fcbe0a4..99633cb7 100644 --- a/tracker.py +++ b/tracker.py @@ -51,19 +51,20 @@ def __iadd__(self, value): def reset(self): self.values = [] - def mean(self): + def mean(self,dim=[]): with torch.no_grad(): if len(self.values) == 0: return 0.0 - return float(torch.cat(self.values).mean().item()) + return torch.cat(self.values).mean(dim=dim).numpy() class LossTracker: - def __init__(self, output_folder='.'): + def __init__(self, output_folder='.',test=False): self.tracks = OrderedDict() self.epochs = [] self.means_over_epochs = OrderedDict() self.output_folder = output_folder + self.filename = 'log_test.csv' if test else 'log_train.csv' def update(self, d): for k, v in d.items(): @@ -87,17 +88,17 @@ def register_means(self, epoch): for key in self.means_over_epochs.keys(): if key in self.tracks: value = self.tracks[key] - self.means_over_epochs[key].append(value.mean()) + self.means_over_epochs[key].append(value.mean(dim=0)) value.reset() else: 
self.means_over_epochs[key].append(None) - - with open(os.path.join(self.output_folder, 'log.csv'), mode='w') as csv_file: - fieldnames = ['epoch'] + list(self.tracks.keys()) + with open(os.path.join(self.output_folder, self.filename), mode='w') as csv_file: + fieldnames = ['epoch'] + [key+str(i) for key in list(self.tracks.keys()) for i in range(self.means_over_epochs[key][0].size)] + # fieldnames = ['epoch'] + [list(self.tracks.keys())] writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) writer.writerow(fieldnames) for i in range(len(self.epochs)): - writer.writerow([self.epochs[i]] + [self.means_over_epochs[x][i] for x in self.tracks.keys()]) + writer.writerow([self.epochs[i]] + [self.means_over_epochs[x][i][j] if self.means_over_epochs[x][i].size>1 else self.means_over_epochs[x][i] for x in self.tracks.keys() for j in range(self.means_over_epochs[x][i].size) ]) def __str__(self): result = "" diff --git a/train_alae.py b/train_alae.py index 9f1f30d2..1cbd3834 100644 --- a/train_alae.py +++ b/train_alae.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - +import json import torch.utils.data from torchvision.utils import save_image from net import * @@ -21,7 +21,7 @@ from checkpointer import Checkpointer from scheduler import ComboMultiStepLR from custom_adam import LREQAdam -from dataloader import * +from dataloader_ecog import * from tqdm import tqdm from dlutils.pytorch import count_parameters import dlutils.pytorch.count_parameters as count_param_override @@ -31,69 +31,125 @@ from defaults import get_cfg_defaults import lod_driver from PIL import Image - - -def save_sample(lod2batch, tracker, sample, samplez, x, logger, model, cfg, encoder_optimizer, decoder_optimizer): +import numpy as np +from torch import autograd +from ECoGDataSet import concate_batch +def save_sample(lod2batch, tracker, sample, samplez, samplez_global, x, logger, model, cfg, encoder_optimizer, decoder_optimizer,filename=None,ecog=None,mask_prior=None,mode='test'): os.makedirs('results', exist_ok=True) - logger.info('\n[%d/%d] - ptime: %.2f, %s, blend: %.3f, lr: %.12f, %.12f, max mem: %f",' % ( (lod2batch.current_epoch + 1), cfg.TRAIN.TRAIN_EPOCHS, lod2batch.per_epoch_ptime, str(tracker), lod2batch.get_blend_factor(), encoder_optimizer.param_groups[0]['lr'], decoder_optimizer.param_groups[0]['lr'], torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)) - + # sample = sample.transpose(-2,-1) with torch.no_grad(): model.eval() - sample = sample[:lod2batch.get_per_GPU_batch_size()] - samplez = samplez[:lod2batch.get_per_GPU_batch_size()] + # sample = sample[:lod2batch.get_per_GPU_batch_size()] + # samplez = samplez[:lod2batch.get_per_GPU_batch_size()] needed_resolution = model.decoder.layer_to_resolution[lod2batch.lod] - sample_in = sample - while sample_in.shape[2] > needed_resolution: - sample_in = F.avg_pool2d(sample_in, 2, 2) - assert sample_in.shape[2] == needed_resolution - - blend_factor = lod2batch.get_blend_factor() - if lod2batch.in_transition: - needed_resolution_prev = model.decoder.layer_to_resolution[lod2batch.lod - 1] - sample_in_prev = F.avg_pool2d(sample_in, 2, 2) - sample_in_prev_2x = F.interpolate(sample_in_prev, needed_resolution) - sample_in = sample_in * blend_factor + sample_in_prev_2x * (1.0 - blend_factor) - - Z, _ = model.encode(sample_in, lod2batch.lod, blend_factor) - - if cfg.MODEL.Z_REGRESSION: - Z = 
model.mapping_fl(Z[:, 0]) - else: - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) - - rec1 = model.decoder(Z, lod2batch.lod, blend_factor, noise=False) - rec2 = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) - - # rec1 = F.interpolate(rec1, sample.shape[2]) - # rec2 = F.interpolate(rec2, sample.shape[2]) - # sample_in = F.interpolate(sample_in, sample.shape[2]) - - Z = model.mapping_fl(samplez) - g_rec = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) - # g_rec = F.interpolate(g_rec, sample.shape[2]) - - resultsample = torch.cat([sample_in, rec1, rec2, g_rec], dim=0) - - @utils.async_func + sample_in_all = torch.tensor([]) + rec1_all = torch.tensor([]) + rec2_all = torch.tensor([]) + g_rec_all = torch.tensor([]) + for i in range(0,sample.shape[0],9): + sample_in = sample[i:np.minimum(i+9,sample.shape[0])] + if ecog is not None: + ecog_in = [ecog[j][i:np.minimum(i+9,sample.shape[0])] for j in range(len(ecog))] + mask_prior_in = [mask_prior[j][i:np.minimum(i+9,sample.shape[0])] for j in range(len(mask_prior))] + x_in = x[i:np.minimum(i+9,sample.shape[0])] + samplez_in = samplez[i:np.minimum(i+9,sample.shape[0])] + samplez_global_in = samplez_global[i:np.minimum(i+9,sample.shape[0])] + while sample_in.shape[2] > needed_resolution: + sample_in = F.avg_pool2d(sample_in, 2, 2) + assert sample_in.shape[2] == needed_resolution + + blend_factor = lod2batch.get_blend_factor() + if lod2batch.in_transition: + needed_resolution_prev = model.decoder.layer_to_resolution[lod2batch.lod - 1] + sample_in_prev = F.avg_pool2d(sample_in, 2, 2) + sample_in_prev_2x = F.interpolate(sample_in_prev, scale_factor=2) + sample_in = sample_in * blend_factor + sample_in_prev_2x * (1.0 - blend_factor) + + Z, _ = model.encode(sample_in, lod2batch.lod, blend_factor) + if cfg.MODEL.TEMPORAL_W and cfg.MODEL.GLOBAL_W: + Z,Z_global = Z + if cfg.MODEL.Z_REGRESSION: + Z = model.mapping_fl(Z[:, 0]) + else: + if cfg.MODEL.TEMPORAL_W and cfg.MODEL.GLOBAL_W: + Z = Z.repeat(1, model.mapping_fl.num_layers, 1,1) + Z_global = Z_global.repeat(1, model.mapping_fl.num_layers, 1) + Z = (Z, Z_global) + else: + if cfg.MODEL.TEMPORAL_W: + Z = Z.repeat(1, model.mapping_fl.num_layers, 1,1) + else: + Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + + rec1 = model.decoder(Z, lod2batch.lod, blend_factor, noise=False) + rec2 = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) + + # rec1 = F.interpolate(rec1, sample.shape[2]) + # rec2 = F.interpolate(rec2, sample.shape[2]) + # sample_in = F.interpolate(sample_in, sample.shape[2]) + + if ecog is not None: + Z = model.ecog_encoder(ecog = ecog_in, mask_prior = mask_prior_in) + if cfg.MODEL.TEMPORAL_W and cfg.MODEL.GLOBAL_W: + Z, Z_global = Z + Z = Z.view(Z.shape[0], 1, Z.shape[1],Z.shape[2]).repeat(1, model.mapping_fl.num_layers, 1, 1) + Z_global = Z_global.view(Z_global.shape[0], 1, Z_global.shape[1]).repeat(1, model.mapping_fl.num_layers, 1) + Z = (Z,Z_global) + else: + if cfg.MODEL.TEMPORAL_W: + Z = Z.view(Z.shape[0], 1, Z.shape[1],Z.shape[2]).repeat(1, model.mapping_fl.num_layers, 1, 1) + else: + Z = Z.view(Z.shape[0], 1, Z.shape[1]).repeat(1, model.mapping_fl.num_layers, 1) + else: + Z = model.mapping_fl(samplez_in,samplez_global_in) + g_rec = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) + + # g_rec = model.generate(lod2batch.lod, blend_factor, count=ecog_in[0].shape[0], z=samplez_in.detach(), z_global=samplez_global_in, noise=True,return_styles=False,ecog=ecog_in,mask_prior=mask_prior_in) + + + # g_rec = F.interpolate(g_rec, sample.shape[2]) + 
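+            # accumulate the per-chunk inputs and reconstructions so the full test batch can be written out as a single image grid below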
sample_in_all = torch.cat([sample_in_all,sample_in],dim=0) + rec1_all = torch.cat([rec1_all,rec1],dim=0) + rec2_all = torch.cat([rec2_all,rec2],dim=0) + g_rec_all = torch.cat([g_rec_all,g_rec],dim=0) + + print(mode+' suploss is',torch.mean((g_rec_all-sample_in_all).abs())) + resultsample = torch.cat([sample_in_all, rec1_all, rec2_all, g_rec_all], dim=0) + if cfg.DATASET.BCTS: + resultsample = resultsample.transpose(-2,-1) + + # @utils.async_func def save_pic(x_rec): - tracker.register_means(lod2batch.current_epoch + lod2batch.iteration * 1.0 / lod2batch.get_dataset_size()) - tracker.plot() + if mode=='test': + tracker.register_means(lod2batch.current_epoch + lod2batch.iteration * 1.0 / lod2batch.get_dataset_size()) + # tracker.plot() result_sample = x_rec * 0.5 + 0.5 result_sample = result_sample.cpu() - f = os.path.join(cfg.OUTPUT_DIR, - 'sample_%d_%d.jpg' % ( - lod2batch.current_epoch + 1, - lod2batch.iteration // 1000) - ) + if filename: + f =filename + else: + if mode == 'test': + f = os.path.join(cfg.OUTPUT_DIR, + 'sample_%d_%d.jpg' % ( + lod2batch.current_epoch + 1, + lod2batch.iteration // 1000) + ) + else: + f = os.path.join(cfg.OUTPUT_DIR, + 'sample_train_%d_%d.jpg' % ( + lod2batch.current_epoch + 1, + lod2batch.iteration // 1000) + ) print("Saved to %s" % f) - save_image(result_sample, f, nrow=min(32, lod2batch.get_per_GPU_batch_size())) + # save_image(result_sample, f, nrow=min(32, lod2batch.get_per_GPU_batch_size())) + save_image(result_sample, f, nrow=x_rec.shape[0]//4) save_pic(resultsample) @@ -111,7 +167,32 @@ def train(cfg, logger, local_rank, world_size, distributed): channels=cfg.MODEL.CHANNELS, generator=cfg.MODEL.GENERATOR, encoder=cfg.MODEL.ENCODER, - z_regression=cfg.MODEL.Z_REGRESSION + ecog_encoder=cfg.MODEL.MAPPING_FROM_ECOG, + z_regression=cfg.MODEL.Z_REGRESSION, + average_w = cfg.MODEL.AVERAGE_W, + temporal_w = cfg.MODEL.TEMPORAL_W, + global_w = cfg.MODEL.GLOBAL_W, + temporal_global_cat = cfg.MODEL.TEMPORAL_GLOBAL_CAT, + spec_chans = cfg.DATASET.SPEC_CHANS, + temporal_samples = cfg.DATASET.TEMPORAL_SAMPLES, + init_zeros = cfg.MODEL.TEMPORAL_W, + residual = cfg.MODEL.RESIDUAL, + w_classifier = cfg.MODEL.W_CLASSIFIER, + uniq_words = cfg.MODEL.UNIQ_WORDS, + attention = cfg.MODEL.ATTENTION, + cycle = cfg.MODEL.CYCLE, + w_weight = cfg.TRAIN.W_WEIGHT, + cycle_weight=cfg.TRAIN.CYCLE_WEIGHT, + attentional_style=cfg.MODEL.ATTENTIONAL_STYLE, + heads = cfg.MODEL.HEADS, + suploss_on_ecog = cfg.MODEL.SUPLOSS_ON_ECOGF, + less_temporal_feature = cfg.MODEL.LESS_TEMPORAL_FEATURE, + ppl_weight=cfg.MODEL.PPL_WEIGHT, + ppl_global_weight=cfg.MODEL.PPL_GLOBAL_WEIGHT, + ppld_weight=cfg.MODEL.PPLD_WEIGHT, + ppld_global_weight=cfg.MODEL.PPLD_GLOBAL_WEIGHT, + common_z = cfg.MODEL.COMMON_Z, + with_ecog = cfg.MODEL.ECOG, ) model.cuda(local_rank) model.train() @@ -128,11 +209,37 @@ def train(cfg, logger, local_rank, world_size, distributed): channels=cfg.MODEL.CHANNELS, generator=cfg.MODEL.GENERATOR, encoder=cfg.MODEL.ENCODER, - z_regression=cfg.MODEL.Z_REGRESSION) + ecog_encoder=cfg.MODEL.MAPPING_FROM_ECOG, + z_regression=cfg.MODEL.Z_REGRESSION, + average_w = cfg.MODEL.AVERAGE_W, + spec_chans = cfg.DATASET.SPEC_CHANS, + temporal_samples = cfg.DATASET.TEMPORAL_SAMPLES, + temporal_w = cfg.MODEL.TEMPORAL_W, + global_w = cfg.MODEL.GLOBAL_W, + temporal_global_cat = cfg.MODEL.TEMPORAL_GLOBAL_CAT, + init_zeros = cfg.MODEL.TEMPORAL_W, + residual = cfg.MODEL.RESIDUAL, + w_classifier = cfg.MODEL.W_CLASSIFIER, + uniq_words = cfg.MODEL.UNIQ_WORDS, + attention = cfg.MODEL.ATTENTION, + cycle = 
cfg.MODEL.CYCLE, + w_weight = cfg.TRAIN.W_WEIGHT, + cycle_weight=cfg.TRAIN.CYCLE_WEIGHT, + attentional_style=cfg.MODEL.ATTENTIONAL_STYLE, + heads = cfg.MODEL.HEADS, + suploss_on_ecog = cfg.MODEL.SUPLOSS_ON_ECOGF, + less_temporal_feature = cfg.MODEL.LESS_TEMPORAL_FEATURE, + ppl_weight=cfg.MODEL.PPL_WEIGHT, + ppl_global_weight=cfg.MODEL.PPL_GLOBAL_WEIGHT, + ppld_weight=cfg.MODEL.PPLD_WEIGHT, + ppld_global_weight=cfg.MODEL.PPLD_GLOBAL_WEIGHT, + common_z = cfg.MODEL.COMMON_Z, + with_ecog = cfg.MODEL.ECOG, + ) model_s.cuda(local_rank) model_s.eval() model_s.requires_grad_(False) - + # print(model) if distributed: model = nn.parallel.DistributedDataParallel( model, @@ -147,12 +254,24 @@ def train(cfg, logger, local_rank, world_size, distributed): mapping_tl = model.module.mapping_tl mapping_fl = model.module.mapping_fl dlatent_avg = model.module.dlatent_avg + ppl_mean = model.module.ppl_mean + ppl_d_mean = model.module.ppl_d_mean + if hasattr(model,'ecog_encoder'): + ecog_encoder = model.module.ecog_encoder + if cfg.MODEL.W_CLASSIFIER: + mapping_tw = model.module.mapping_tw else: decoder = model.decoder encoder = model.encoder mapping_tl = model.mapping_tl mapping_fl = model.mapping_fl dlatent_avg = model.dlatent_avg + ppl_mean = model.ppl_mean + if hasattr(model,'ecog_encoder'): + ecog_encoder = model.ecog_encoder + ppl_d_mean = model.ppl_d_mean + if cfg.MODEL.W_CLASSIFIER: + mapping_tw = model.mapping_tw count_param_override.print = lambda a: logger.info(a) @@ -165,15 +284,33 @@ def train(cfg, logger, local_rank, world_size, distributed): arguments = dict() arguments["iteration"] = 0 - decoder_optimizer = LREQAdam([ - {'params': decoder.parameters()}, - {'params': mapping_fl.parameters()} - ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) - - encoder_optimizer = LREQAdam([ - {'params': encoder.parameters()}, - {'params': mapping_tl.parameters()}, - ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + if cfg.MODEL.ECOG: + decoder_optimizer = LREQAdam([ + {'params': decoder.parameters()}, + {'params': ecog_encoder.parameters()} + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + else: + decoder_optimizer = LREQAdam([ + {'params': decoder.parameters()}, + {'params': mapping_fl.parameters()} + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + + if cfg.MODEL.ECOG: + ecog_encoder_optimizer = LREQAdam([ + {'params': ecog_encoder.parameters()} + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + + if cfg.MODEL.W_CLASSIFIER: + encoder_optimizer = LREQAdam([ + {'params': encoder.parameters()}, + {'params': mapping_tl.parameters()}, + {'params': mapping_tw.parameters()}, + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + else: + encoder_optimizer = LREQAdam([ + {'params': encoder.parameters()}, + {'params': mapping_tl.parameters()}, + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) scheduler = ComboMultiStepLR(optimizers= { @@ -183,47 +320,58 @@ def train(cfg, logger, local_rank, world_size, distributed): milestones=cfg.TRAIN.LEARNING_DECAY_STEPS, gamma=cfg.TRAIN.LEARNING_DECAY_RATE, reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES) - model_dict = { 'discriminator': encoder, 'generator': decoder, 
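+        # every module registered in model_dict is saved and restored together by the Checkpointer below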
'mapping_tl': mapping_tl, 'mapping_fl': mapping_fl, - 'dlatent_avg': dlatent_avg + 'dlatent_avg': dlatent_avg, + 'ppl_mean':ppl_mean, + 'ppl_d_mean':ppl_d_mean, } + if hasattr(model,'ecog_encoder'): + model_dict['ecog_encoder'] = ecog_encoder if local_rank == 0: model_dict['discriminator_s'] = model_s.encoder model_dict['generator_s'] = model_s.decoder model_dict['mapping_tl_s'] = model_s.mapping_tl model_dict['mapping_fl_s'] = model_s.mapping_fl + if hasattr(model_s,'ecog_encoder'): + model_dict['ecog_encoder_s'] = model_s.ecog_encoder tracker = LossTracker(cfg.OUTPUT_DIR) + auxiliary = {'encoder_optimizer': encoder_optimizer, + 'decoder_optimizer': decoder_optimizer, + 'scheduler': scheduler, + 'tracker': tracker + } + if cfg.MODEL.ECOG: + auxiliary['ecog_encoder_optimizer']=ecog_encoder_optimizer checkpointer = Checkpointer(cfg, model_dict, - { - 'encoder_optimizer': encoder_optimizer, - 'decoder_optimizer': decoder_optimizer, - 'scheduler': scheduler, - 'tracker': tracker - }, + auxiliary, logger=logger, save=local_rank == 0) - extra_checkpoint_data = checkpointer.load() + # extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=False,ignore_auxiliary=True,file_name='./training_artifacts/ecog_residual_cycle/model_tmp_lod4.pth') + extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=True,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/ecog_residual_latent128_temporal_lesstemporalfeature_noprogressive_HBw_ppl_ppld_localreg_debug/model_tmp_lod6.pth') logger.info("Starting from epoch: %d" % (scheduler.start_epoch())) arguments.update(extra_checkpoint_data) layer_to_resolution = decoder.layer_to_resolution - - dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS) + with open('train_param.json','r') as rfile: + param = json.load(rfile) + # data_param, train_param, test_param = param['Data'], param['Train'], param['Test'] + dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS,param=param) + dataset_test = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS,train=False,param=param) rnd = np.random.RandomState(3456) - latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE) - samplez = torch.tensor(latents).float().cuda() + # latents = rnd.randn(len(dataset_test.dataset), cfg.MODEL.LATENT_SPACE_SIZE) + # samplez = torch.tensor(latents).float().cuda() - lod2batch = lod_driver.LODDriver(cfg, logger, world_size, dataset_size=len(dataset) * world_size) + lod2batch = lod_driver.LODDriver(cfg, logger, world_size, dataset_size=len(dataset) * world_size, progressive = (not(cfg.FINETUNE.FINETUNE) and cfg.TRAIN.PROGRESSIVE)) if cfg.DATASET.SAMPLES_PATH: path = cfg.DATASET.SAMPLES_PATH @@ -239,10 +387,27 @@ def train(cfg, logger, local_rank, world_size, distributed): x = x[:3] src.append(x) sample = torch.stack(src) + latents = rnd.randn(sample.shape[0], cfg.MODEL.LATENT_SPACE_SIZE) + latents_global = latents if cfg.MODEL.COMMON_Z else rnd.randn(sample.shape[0], cfg.MODEL.LATENT_SPACE_SIZE) + samplez = torch.tensor(latents).float().cuda() + samplez_global = torch.tensor(latents_global).float().cuda() else: - dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32) - sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank)) - sample = (sample / 127.5 - 1.) 
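+        # draw one fixed batch from the test set so each checkpoint is visualized on the same samples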
+ dataset_test.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, len(dataset_test.dataset)) + sample_dict_test = next(iter(dataset_test.iterator)) + # sample_dict_test = concate_batch(sample_dict_test) + sample_spec_test = sample_dict_test['spkr_re_batch_all'].to('cuda').float() + latents = rnd.randn(sample_spec_test.shape[0], cfg.MODEL.LATENT_SPACE_SIZE) + latents_global = latents if cfg.MODEL.COMMON_Z else rnd.randn(sample_spec_test.shape[0], cfg.MODEL.LATENT_SPACE_SIZE) + samplez = torch.tensor(latents).float().cuda() + samplez_global = torch.tensor(latents_global).float().cuda() + if cfg.MODEL.ECOG: + ecog_test = [sample_dict_test['ecog_re_batch_all'][i].to('cuda').float() for i in range(len(sample_dict_test['ecog_re_batch_all']))] + mask_prior_test = [sample_dict_test['mask_all'][i].to('cuda').float() for i in range(len(sample_dict_test['mask_all']))] + else: + ecog_test = None + mask_prior_test = None + # sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank)) + # sample = (sample / 127.5 - 1.) lod2batch.set_epoch(scheduler.start_epoch(), [encoder_optimizer, decoder_optimizer]) @@ -260,7 +425,8 @@ def train(cfg, logger, local_rank, world_size, distributed): len(dataset) * world_size)) dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size()) - batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank) + + # batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank) scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod) @@ -270,69 +436,101 @@ def train(cfg, logger, local_rank, world_size, distributed): epoch_start_time = time.time() i = 0 - for x_orig in tqdm(batches): + for sample_dict_train in tqdm(iter(dataset.iterator)): + # sample_dict_train = concate_batch(sample_dict_train) i += 1 + x_orig = sample_dict_train['spkr_re_batch_all'].to('cuda').float() + words = sample_dict_train['word_batch_all'].to('cuda').long() + words = words.view(words.shape[0]*words.shape[1]) + if cfg.MODEL.ECOG: + ecog = [sample_dict_train['ecog_re_batch_all'][j].to('cuda').float() for j in range(len(sample_dict_train['ecog_re_batch_all']))] + mask_prior = [sample_dict_train['mask_all'][j].to('cuda').float() for j in range(len(sample_dict_train['mask_all']))] + else: + ecog = None + mask_prior = None with torch.no_grad(): - if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size(): - continue - if need_permute: - x_orig = x_orig.permute(0, 3, 1, 2) - x_orig = (x_orig / 127.5 - 1.) - + # if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size(): + # continue + # if need_permute: + # x_orig = x_orig.permute(0, 3, 1, 2) + # x_orig = (x_orig / 127.5 - 1.) 
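+                # downsample the spectrogram to the resolution of the current LOD (replaces the /127.5 image normalization used for picture datasets)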
+ x_orig = F.avg_pool2d(x_orig,x_orig.shape[-2]//2**lod2batch.get_lod_power2(),x_orig.shape[-2]//2**lod2batch.get_lod_power2()) + # x_orig = F.interpolate(x_orig, [x_orig.shape[-1]//2**lod2batch.get_lod_power2(),x_orig.shape[-1]//2**lod2batch.get_lod_power2()],mode='bilinear',align_corners=False) blend_factor = lod2batch.get_blend_factor() - needed_resolution = layer_to_resolution[lod2batch.lod] x = x_orig - if lod2batch.in_transition: needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1] x_prev = F.avg_pool2d(x_orig, 2, 2) - x_prev_2x = F.interpolate(x_prev, needed_resolution) + x_prev_2x = F.interpolate(x_prev, scale_factor=2) + # x_prev_2x = F.interpolate(x_prev, needed_resolution,mode='bilinear',align_corners=False) x = x * blend_factor + x_prev_2x * (1.0 - blend_factor) x.requires_grad = True - - encoder_optimizer.zero_grad() - loss_d = model(x, lod2batch.lod, blend_factor, d_train=True, ae=False) - tracker.update(dict(loss_d=loss_d)) - loss_d.backward() - encoder_optimizer.step() - - decoder_optimizer.zero_grad() - loss_g = model(x, lod2batch.lod, blend_factor, d_train=False, ae=False) - tracker.update(dict(loss_g=loss_g)) - loss_g.backward() - decoder_optimizer.step() - - encoder_optimizer.zero_grad() + apply_cycle = cfg.MODEL.CYCLE and True + apply_w_classifier = cfg.MODEL.W_CLASSIFIER and True + apply_gp = True + apply_ppl = cfg.MODEL.APPLY_PPL and True + apply_ppl_d = cfg.MODEL.APPLY_PPL_D and True + apply_encoder_guide = (cfg.FINETUNE.ENCODER_GUIDE or cfg.MODEL.W_SUP) and True + apply_sup = cfg.FINETUNE.SPECSUP + + if not (cfg.FINETUNE.FINETUNE): + encoder_optimizer.zero_grad() + loss_d = model(x, lod2batch.lod, blend_factor, tracker = tracker, d_train=True, ae=False,words=words,apply_w_classifier=apply_w_classifier, apply_gp = apply_gp,apply_ppl_d=apply_ppl_d,ecog=ecog,mask_prior=mask_prior) + (loss_d).backward() + encoder_optimizer.step() + + if cfg.MODEL.ECOG: + ecog_encoder_optimizer.zero_grad() decoder_optimizer.zero_grad() - lae = model(x, lod2batch.lod, blend_factor, d_train=True, ae=True) - tracker.update(dict(lae=lae)) - (lae).backward() - encoder_optimizer.step() - decoder_optimizer.step() + loss_g = model(x, lod2batch.lod, blend_factor, tracker = tracker, d_train=False, ae=False,apply_encoder_guide=apply_encoder_guide,apply_ppl=apply_ppl,ecog=ecog,sup=apply_sup,mask_prior=mask_prior,gan=cfg.MODEL.GAN) + if (cfg.MODEL.ECOG and cfg.MODEL.SUPLOSS_ON_ECOGF) or (cfg.FINETUNE.FINETUNE and cfg.FINETUNE.FIX_GEN ): + loss_g,loss_sup = loss_g + # tracker.update(dict(std_scale=model.decoder.std_each_scale)) + if not (cfg.FINETUNE.FINETUNE and cfg.FINETUNE.FIX_GEN): + (loss_g).backward(retain_graph=True) + decoder_optimizer.step() + if (cfg.MODEL.ECOG and cfg.MODEL.SUPLOSS_ON_ECOGF) or (cfg.FINETUNE.FINETUNE and cfg.FINETUNE.FIX_GEN): + loss_sup.backward() + ecog_encoder_optimizer.step() + + if not cfg.FINETUNE.FINETUNE: + encoder_optimizer.zero_grad() + decoder_optimizer.zero_grad() + lae = model(x, lod2batch.lod, blend_factor, tracker = tracker, d_train=True, ae=True,apply_cycle=apply_cycle,ecog=ecog,mask_prior=mask_prior) + (lae).backward() + encoder_optimizer.step() + decoder_optimizer.step() if local_rank == 0: betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0)) - model_s.lerp(model, betta) + model_s.lerp(model, betta,w_classifier = cfg.MODEL.W_CLASSIFIER) epoch_end_time = time.time() per_epoch_ptime = epoch_end_time - epoch_start_time - lod_for_saving_model = lod2batch.lod + lod_for_saving_model = lod2batch.lod if cfg.TRAIN.PROGRESSIVE else int(epoch//1) 
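+            # when progressive growing is disabled, checkpoints are indexed by epoch rather than by LOD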
lod2batch.step() if local_rank == 0: if lod2batch.is_time_to_save(): checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model) if lod2batch.is_time_to_report(): - save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, - decoder_optimizer) + save_sample(lod2batch, tracker, sample_spec_test, samplez, samplez_global, x, logger, model_s, cfg, encoder_optimizer, + decoder_optimizer,ecog=ecog_test,mask_prior=mask_prior_test) + if ecog is not None: + save_sample(lod2batch, tracker, x_orig, samplez, samplez_global, x, logger, model_s, cfg, encoder_optimizer, + decoder_optimizer,ecog=ecog,mask_prior=mask_prior,mode='train') scheduler.step() if local_rank == 0: checkpointer.save("model_tmp_lod%d" % lod_for_saving_model) - save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, decoder_optimizer) + save_sample(lod2batch, tracker, sample_spec_test, samplez, samplez_global, x, logger, model_s, cfg, encoder_optimizer, decoder_optimizer, + ecog=ecog_test,mask_prior=mask_prior_test) + if ecog is not None: + save_sample(lod2batch, tracker, x_orig, samplez, samplez_global, x, logger, model_s, cfg, encoder_optimizer, + decoder_optimizer,ecog=ecog,mask_prior=mask_prior,mode='train') logger.info("Training finish!... save training results") if local_rank == 0: @@ -341,5 +539,5 @@ def train(cfg, logger, local_rank, world_size, distributed): if __name__ == "__main__": gpu_count = torch.cuda.device_count() - run(train, get_cfg_defaults(), description='StyleGAN', default_config='configs/ffhq.yaml', + run(train, get_cfg_defaults(), description='StyleGAN', default_config='configs/ecog_style2.yaml', world_size=gpu_count) diff --git a/train_formant_a.py b/train_formant_a.py new file mode 100644 index 00000000..63736fc2 --- /dev/null +++ b/train_formant_a.py @@ -0,0 +1,396 @@ +# Copyright 2019-2020 Stanislav Pidhorskyi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import json +from os import terminal_size +import pdb +import torch.utils.data +from torchvision.utils import save_image +from net_formant import * +import os +import utils +from checkpointer import Checkpointer +from scheduler import ComboMultiStepLR +from custom_adam import LREQAdam +from dataloader_ecog import * +from tqdm import tqdm +from dlutils.pytorch import count_parameters +import dlutils.pytorch.count_parameters as count_param_override +from tracker import LossTracker +from model_formant import Model +from launcher import run +from defaults import get_cfg_defaults +import lod_driver +from PIL import Image +import numpy as np +from torch import autograd +from ECoGDataSet import concate_batch +from formant_systh import save_sample + +from tensorboardX import SummaryWriter + + +import argparse + +parser = argparse.ArgumentParser(description='Process some integers.') + +parser.add_argument('-m', '--modeldir', type=str,default=' ', + help='') +argus = parser.parse_args() + + +def train(cfg, logger, local_rank, world_size, distributed): + writer = SummaryWriter(cfg.OUTPUT_DIR) + torch.cuda.set_device(local_rank) + model = Model( + generator=cfg.MODEL.GENERATOR, + encoder=cfg.MODEL.ENCODER, + ecog_encoder_name=cfg.MODEL.MAPPING_FROM_ECOG, + spec_chans = cfg.DATASET.SPEC_CHANS, + n_formants = cfg.MODEL.N_FORMANTS, + n_formants_noise = cfg.MODEL.N_FORMANTS_NOISE, + n_formants_ecog = cfg.MODEL.N_FORMANTS_ECOG, + wavebased = cfg.MODEL.WAVE_BASED, + n_fft=cfg.MODEL.N_FFT, + noise_db=cfg.MODEL.NOISE_DB, + max_db=cfg.MODEL.MAX_DB, + with_ecog = cfg.MODEL.ECOG, + hidden_dim=cfg.MODEL.TRANSFORMER.HIDDEN_DIM, + dim_feedforward=cfg.MODEL.TRANSFORMER.DIM_FEEDFORWARD, + encoder_only=cfg.MODEL.TRANSFORMER.ENCODER_ONLY, + attentional_mask=cfg.MODEL.TRANSFORMER.ATTENTIONAL_MASK, + n_heads = cfg.MODEL.TRANSFORMER.N_HEADS, + non_local = cfg.MODEL.TRANSFORMER.NON_LOCAL, + do_mel_guide = cfg.MODEL.DO_MEL_GUIDE, + noise_from_data = cfg.MODEL.BGNOISE_FROMDATA, + specsup=cfg.FINETUNE.SPECSUP, + power_synth = cfg.MODEL.POWER_SYNTH, + ) + model.cuda(local_rank) + model.train() + + model_s = Model( + generator=cfg.MODEL.GENERATOR, + encoder=cfg.MODEL.ENCODER, + ecog_encoder_name=cfg.MODEL.MAPPING_FROM_ECOG, + spec_chans = cfg.DATASET.SPEC_CHANS, + n_formants = cfg.MODEL.N_FORMANTS, + n_formants_noise = cfg.MODEL.N_FORMANTS_NOISE, + n_formants_ecog = cfg.MODEL.N_FORMANTS_ECOG, + wavebased = cfg.MODEL.WAVE_BASED, + n_fft=cfg.MODEL.N_FFT, + noise_db=cfg.MODEL.NOISE_DB, + max_db=cfg.MODEL.MAX_DB, + with_ecog = cfg.MODEL.ECOG, + hidden_dim=cfg.MODEL.TRANSFORMER.HIDDEN_DIM, + dim_feedforward=cfg.MODEL.TRANSFORMER.DIM_FEEDFORWARD, + encoder_only=cfg.MODEL.TRANSFORMER.ENCODER_ONLY, + attentional_mask=cfg.MODEL.TRANSFORMER.ATTENTIONAL_MASK, + n_heads = cfg.MODEL.TRANSFORMER.N_HEADS, + non_local = cfg.MODEL.TRANSFORMER.NON_LOCAL, + do_mel_guide = cfg.MODEL.DO_MEL_GUIDE, + noise_from_data = cfg.MODEL.BGNOISE_FROMDATA, + specsup=cfg.FINETUNE.SPECSUP, + power_synth = cfg.MODEL.POWER_SYNTH, + ) + model_s.cuda(local_rank) + model_s.eval() + model_s.requires_grad_(False) + # print(model) + if distributed: + model = nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + broadcast_buffers=False, + bucket_cap_mb=25, + find_unused_parameters=True) + model.device_ids = None + decoder = model.module.decoder + encoder = model.module.encoder + if hasattr(model.module,'ecog_encoder'): + ecog_encoder = model.module.ecog_encoder + if 
hasattr(model.module,'decoder_mel'): + decoder_mel = model.module.decoder_mel + else: + decoder = model.decoder + encoder = model.encoder + if hasattr(model,'ecog_encoder'): + ecog_encoder = model.ecog_encoder + if hasattr(model,'decoder_mel'): + decoder_mel = model.decoder_mel + + count_param_override.print = lambda a: logger.info(a) + + logger.info("Trainable parameters generator:") + count_parameters(decoder) + + logger.info("Trainable parameters discriminator:") + count_parameters(encoder) + + arguments = dict() + arguments["iteration"] = 0 + + if cfg.MODEL.ECOG: + if cfg.MODEL.SUPLOSS_ON_ECOGF: + optimizer = LREQAdam([ + {'params': ecog_encoder.parameters()} + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + else: + optimizer = LREQAdam([ + {'params': ecog_encoder.parameters()}, + {'params': decoder.parameters()}, + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + + else: + if cfg.MODEL.DO_MEL_GUIDE: + optimizer = LREQAdam([ + {'params': encoder.parameters()}, + {'params': decoder.parameters()}, + {'params': decoder_mel.parameters()}, + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + else: + optimizer = LREQAdam([ + {'params': encoder.parameters()}, + {'params': decoder.parameters()} + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + + scheduler = ComboMultiStepLR(optimizers= + {'optimizer': optimizer}, + milestones=cfg.TRAIN.LEARNING_DECAY_STEPS, + gamma=cfg.TRAIN.LEARNING_DECAY_RATE, + reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES) + model_dict = { + 'encoder': encoder, + 'generator': decoder, + } + if hasattr(model,'ecog_encoder'): + model_dict['ecog_encoder'] = ecog_encoder + if hasattr(model,'decoder_mel'): + model_dict['decoder_mel'] = decoder_mel + if local_rank == 0: + model_dict['encoder_s'] = model_s.encoder + model_dict['generator_s'] = model_s.decoder + if hasattr(model_s,'ecog_encoder'): + model_dict['ecog_encoder_s'] = model_s.ecog_encoder + if hasattr(model_s,'decoder_mel'): + model_dict['decoder_mel_s'] = model_s.decoder_mel + + tracker = LossTracker(cfg.OUTPUT_DIR) + tracker_test = LossTracker(cfg.OUTPUT_DIR,test=True) + + auxiliary = { + 'optimizer': optimizer, + 'scheduler': scheduler, + 'tracker': tracker, + 'tracker_test':tracker_test, + } + + checkpointer = Checkpointer(cfg, + model_dict, + auxiliary, + logger=logger, + save=local_rank == 0) + + extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=True,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='./training_artifacts/ecog_finetune_3ecogformants_han5_specsup_guidance_hamonicformantsemph/model_epoch23.pth') + logger.info("Starting from epoch: %d" % (scheduler.start_epoch())) + + arguments.update(extra_checkpoint_data) + + + with open('train_param.json','r') as rfile: + param = json.load(rfile) + # data_param, train_param, test_param = param['Data'], param['Train'], param['Test'] + dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS,param=param) + dataset_test = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS,train=False,param=param) + # noise_dist = dataset.noise_dist + noise_dist = torch.from_numpy(dataset.noise_dist).to('cuda').float() + if cfg.MODEL.BGNOISE_FROMDATA: + model_s.noise_dist_init(noise_dist) + 
model.noise_dist_init(noise_dist) + rnd = np.random.RandomState(3456) + # latents = rnd.randn(len(dataset_test.dataset), cfg.MODEL.LATENT_SPACE_SIZE) + # samplez = torch.tensor(latents).float().cuda() + + + if cfg.DATASET.SAMPLES_PATH: + path = cfg.DATASET.SAMPLES_PATH + src = [] + with torch.no_grad(): + for filename in list(os.listdir(path))[:32]: + img = np.asarray(Image.open(os.path.join(path, filename))) + if img.shape[2] == 4: + img = img[:, :, :3] + im = img.transpose((2, 0, 1)) + x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1. + if x.shape[0] == 4: + x = x[:3] + src.append(x) + sample = torch.stack(src) + else: + dataset_test.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, len(dataset_test.dataset)) + sample_dict_test = next(iter(dataset_test.iterator)) + # sample_dict_test = concate_batch(sample_dict_test) + sample_wave_test = sample_dict_test['wave_re_batch_all'].to('cuda').float() + if cfg.MODEL.WAVE_BASED: + sample_spec_test = sample_dict_test['wave_spec_re_batch_all'].to('cuda').float() + sample_spec_amp_test = sample_dict_test['wave_spec_re_amp_batch_all'].to('cuda').float() + sample_spec_denoise_test = sample_dict_test['wave_spec_re_denoise_batch_all'].to('cuda').float() + # sample_spec_test = wave2spec(sample_wave_test,n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB) + else: + sample_spec_test = sample_dict_test['spkr_re_batch_all'].to('cuda').float() + sample_spec_denoise_test = None#sample_dict_test['wave_spec_re_denoise_batch_all'].to('cuda').float() + sample_label_test = sample_dict_test['label_batch_all'] + if cfg.MODEL.ECOG: + ecog_test = [sample_dict_test['ecog_re_batch_all'][i].to('cuda').float() for i in range(len(sample_dict_test['ecog_re_batch_all']))] + mask_prior_test = [sample_dict_test['mask_all'][i].to('cuda').float() for i in range(len(sample_dict_test['mask_all']))] + mni_coordinate_test = sample_dict_test['mni_coordinate_all'].to('cuda').float() + else: + ecog_test = None + mask_prior_test = None + mni_coordinate_test = None + sample_spec_mel_test = sample_dict_test['spkr_re_batch_all'].to('cuda').float() if cfg.MODEL.DO_MEL_GUIDE else None + on_stage_test = sample_dict_test['on_stage_re_batch_all'].to('cuda').float() + on_stage_wider_test = sample_dict_test['on_stage_wider_re_batch_all'].to('cuda').float() + # sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank)) + # sample = (sample / 127.5 - 1.) 
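+    # setup done; the epoch loop below runs one optimizer step per batch and logs the tracked losses to TensorBoard every 10 iterations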
+ # import pdb; pdb.set_trace() + duomask=True + # model.eval() + # Lrec = model(sample_spec_test, x_denoise = sample_spec_denoise_test,x_mel = sample_spec_mel_test,ecog=ecog_test if cfg.MODEL.ECOG else None, mask_prior=mask_prior_test if cfg.MODEL.ECOG else None, on_stage = on_stage_test,on_stage_wider = on_stage_wider_test, ae = not cfg.MODEL.ECOG, tracker = tracker_test, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate_test,debug = False,x_amp=sample_spec_amp_test,hamonic_bias = False) + # save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=0,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=True) + n_iter = 0 + for epoch in range(cfg.TRAIN.TRAIN_EPOCHS): + model.train() + + # batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank) + model.train() + need_permute = False + epoch_start_time = time.time() + i = 0 + for sample_dict_train in tqdm(iter(dataset.iterator)): + n_iter +=1 + # import pdb; pdb.set_trace() + # sample_dict_train = concate_batch(sample_dict_train) + i += 1 + wave_orig = sample_dict_train['wave_re_batch_all'].to('cuda').float() + if cfg.MODEL.WAVE_BASED: + # x_orig = wave2spec(wave_orig,n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB) + x_orig = sample_dict_train['wave_spec_re_batch_all'].to('cuda').float() + x_orig_amp = sample_dict_train['wave_spec_re_amp_batch_all'].to('cuda').float() + x_orig_denoise = sample_dict_train['wave_spec_re_denoise_batch_all'].to('cuda').float() + else: + x_orig = sample_dict_train['spkr_re_batch_all'].to('cuda').float() + x_orig_denoise = None#sample_dict_train['wave_spec_re_denoise_batch_all'].to('cuda').float() + + on_stage = sample_dict_train['on_stage_re_batch_all'].to('cuda').float() + on_stage_wider = sample_dict_train['on_stage_wider_re_batch_all'].to('cuda').float() + words = sample_dict_train['word_batch_all'].to('cuda').long() + words = words.view(words.shape[0]*words.shape[1]) + labels = sample_dict_train['label_batch_all'] + if cfg.MODEL.ECOG: + ecog = [sample_dict_train['ecog_re_batch_all'][j].to('cuda').float() for j in range(len(sample_dict_train['ecog_re_batch_all']))] + mask_prior = [sample_dict_train['mask_all'][j].to('cuda').float() for j in range(len(sample_dict_train['mask_all']))] + mni_coordinate = sample_dict_train['mni_coordinate_all'].to('cuda').float() + else: + ecog = None + mask_prior = None + mni_coordinate = None + x = x_orig + x_mel = sample_dict_train['spkr_re_batch_all'].to('cuda').float() if cfg.MODEL.DO_MEL_GUIDE else None + # x.requires_grad = True + # apply_cycle = cfg.MODEL.CYCLE and True + # apply_w_classifier = cfg.MODEL.W_CLASSIFIER and True + # apply_gp = True + # apply_ppl = cfg.MODEL.APPLY_PPL and True + # apply_ppl_d = cfg.MODEL.APPLY_PPL_D and True + # apply_encoder_guide = (cfg.FINETUNE.ENCODER_GUIDE or cfg.MODEL.W_SUP) and True + # apply_sup = cfg.FINETUNE.SPECSUP + + if (cfg.MODEL.ECOG): + optimizer.zero_grad() + Lrec,tracker = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=ecog, mask_prior=mask_prior, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = False, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,duomask=duomask,mni=mni_coordinate,x_amp=x_orig_amp) + #print 
('tracker',tracker,tracker.tracks) + #for key in ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']: + # print ('tracker, ',key, tracker.tracks[key].mean(dim=0)) + ecog_key_list = ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff'] + if n_iter %10==0: + #print ('write tensorboard') + writer.add_scalars('data/loss_group', {key: tracker.tracks[key].mean(dim=0) for key in ecog_key_list}, n_iter) + for key in ecog_key_list: + writer.add_scalar('data/'+key, tracker.tracks[key].mean(dim=0), n_iter) + #pass #log in tensorboard later!! + #tracker Lae_a: 0.0536877, Lae_a_l2: 0.0538647, Lae_db: 0.1655398, Lae_db_l2: 0.1655714, Lloudness: 1.0384552, Lae_denoise: 0.0485827, Lamp: 0.0000148, Lae: 2.1787138, Lexp: -1.9956266, Lf0: 0.0000000, Ldiff: 0.0568467 + (Lrec).backward() + optimizer.step() + else: + optimizer.zero_grad() + Lrec,tracker = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=None, mask_prior=None, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = True, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate,debug = False,x_amp=x_orig_amp,hamonic_bias = False)#epoch<2) + #print ('tracker',tracker,tracker.tracks) + #for key in ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']: + #print ('tracker, ',key, tracker.tracks[key].mean(dim=0)) + ecog_key_list = ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff'] + if n_iter %10==0: + #print ('write tensorboard') + writer.add_scalars('data/loss_group', {key: tracker.tracks[key].mean(dim=0) for key in ecog_key_list}, n_iter) + for key in ecog_key_list: + writer.add_scalar('data/'+key, tracker.tracks[key].mean(dim=0), n_iter) + (Lrec).backward() + optimizer.step() + + betta = 0.5 ** (cfg.TRAIN.BATCH_SIZE / (10 * 1000.0)) + model_s.lerp(model, betta,w_classifier = cfg.MODEL.W_CLASSIFIER) + + epoch_end_time = time.time() + per_epoch_ptime = epoch_end_time - epoch_start_time + #print ('test save sample') + #save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=sample_spec_amp_test) + #save_sample(x,ecog,mask_prior,mni_coordinate,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=x_orig_denoise,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=labels,mode='train',path=cfg.OUTPUT_DIR,tracker = tracker,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=x_orig_amp) + #print ('finish') + + if local_rank == 0: + print(2**(torch.tanh(model.encoder.formant_bandwitdh_slop))) + checkpointer.save("model_epoch%d" % epoch) + model.eval() + Lrec = model(sample_spec_test, x_denoise = sample_spec_denoise_test,x_mel = sample_spec_mel_test,ecog=ecog_test if cfg.MODEL.ECOG else None, mask_prior=mask_prior_test if cfg.MODEL.ECOG else None, on_stage = on_stage_test,on_stage_wider = on_stage_wider_test, ae = not cfg.MODEL.ECOG, tracker = tracker_test, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate_test,debug = 
False,x_amp=sample_spec_amp_test,hamonic_bias = False) + if epoch%1==0: + #first mode is test + reconaudio = save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=sample_spec_amp_test) + save_sample(x,ecog,mask_prior,mni_coordinate,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=x_orig_denoise,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=labels,mode='train',path=cfg.OUTPUT_DIR,tracker = tracker,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=x_orig_amp) + writer.add_audio('reconaudio', reconaudio, n_iter, sample_rate=16000) + writer.export_scalars_to_json(cfg.OUTPUT_DIR+"/all_scalars.json") + writer.close() + +if __name__ == "__main__": + gpu_count = torch.cuda.device_count() + cfg = get_cfg_defaults() + parser = argparse.ArgumentParser(description='formant') + parser.add_argument( + "-c", "--config-file", + default='configs/ecog_style2_a.yaml', + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + args = parser.parse_args() + + #if args.modeldir !='': + # cfg.OUTPUT_DIR = args.modeldir + run(train, cfg, description='StyleGAN', default_config='configs/ecog_style2_a.yaml', + world_size=gpu_count,args=args) diff --git a/train_formant_e.py b/train_formant_e.py new file mode 100644 index 00000000..b80b1f59 --- /dev/null +++ b/train_formant_e.py @@ -0,0 +1,427 @@ +# Copyright 2019-2020 Stanislav Pidhorskyi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import json +from os import terminal_size +import pdb +import torch.utils.data +from torchvision.utils import save_image +from net_formant import * +import os +import utils +from checkpointer import Checkpointer +from scheduler import ComboMultiStepLR +from custom_adam import LREQAdam +from dataloader_ecog import * +from tqdm import tqdm +from dlutils.pytorch import count_parameters +import dlutils.pytorch.count_parameters as count_param_override +from tracker import LossTracker +from model_formant import Model +from launcher import run +from defaults import get_cfg_defaults +import lod_driver +from PIL import Image +import numpy as np +from torch import autograd +from ECoGDataSet import concate_batch +from formant_systh import save_sample +import argparse +from tensorboardX import SummaryWriter + + + + +def train(cfg, logger, local_rank, world_size, distributed): + writer = SummaryWriter(cfg.OUTPUT_DIR) + torch.cuda.set_device(local_rank) + model = Model( + generator=cfg.MODEL.GENERATOR, + encoder=cfg.MODEL.ENCODER, + ecog_encoder_name=cfg.MODEL.MAPPING_FROM_ECOG, + spec_chans = cfg.DATASET.SPEC_CHANS, + n_formants = cfg.MODEL.N_FORMANTS, + n_formants_noise = cfg.MODEL.N_FORMANTS_NOISE, + n_formants_ecog = cfg.MODEL.N_FORMANTS_ECOG, + wavebased = cfg.MODEL.WAVE_BASED, + n_fft=cfg.MODEL.N_FFT, + noise_db=cfg.MODEL.NOISE_DB, + max_db=cfg.MODEL.MAX_DB, + with_ecog = cfg.MODEL.ECOG, + hidden_dim=cfg.MODEL.TRANSFORMER.HIDDEN_DIM, + dim_feedforward=cfg.MODEL.TRANSFORMER.DIM_FEEDFORWARD, + encoder_only=cfg.MODEL.TRANSFORMER.ENCODER_ONLY, + attentional_mask=cfg.MODEL.TRANSFORMER.ATTENTIONAL_MASK, + n_heads = cfg.MODEL.TRANSFORMER.N_HEADS, + non_local = cfg.MODEL.TRANSFORMER.NON_LOCAL, + do_mel_guide = cfg.MODEL.DO_MEL_GUIDE, + noise_from_data = cfg.MODEL.BGNOISE_FROMDATA, + specsup=cfg.FINETUNE.SPECSUP, + power_synth = cfg.MODEL.POWER_SYNTH, + onedconfirst=cfg.MODEL.ONEDCONFIRST, + rnn_type = cfg.MODEL.RNN_TYPE, + rnn_layers = cfg.MODEL.RNN_LAYERS, + compute_db_loudness=cfg.MODEL.RNN_COMPUTE_DB_LOUDNESS, + bidirection = cfg.MODEL.BIDIRECTION + ) + model.cuda(local_rank) + model.train() + + model_s = Model( + generator=cfg.MODEL.GENERATOR, + encoder=cfg.MODEL.ENCODER, + ecog_encoder_name=cfg.MODEL.MAPPING_FROM_ECOG, + spec_chans = cfg.DATASET.SPEC_CHANS, + n_formants = cfg.MODEL.N_FORMANTS, + n_formants_noise = cfg.MODEL.N_FORMANTS_NOISE, + n_formants_ecog = cfg.MODEL.N_FORMANTS_ECOG, + wavebased = cfg.MODEL.WAVE_BASED, + n_fft=cfg.MODEL.N_FFT, + noise_db=cfg.MODEL.NOISE_DB, + max_db=cfg.MODEL.MAX_DB, + with_ecog = cfg.MODEL.ECOG, + hidden_dim=cfg.MODEL.TRANSFORMER.HIDDEN_DIM, + dim_feedforward=cfg.MODEL.TRANSFORMER.DIM_FEEDFORWARD, + encoder_only=cfg.MODEL.TRANSFORMER.ENCODER_ONLY, + attentional_mask=cfg.MODEL.TRANSFORMER.ATTENTIONAL_MASK, + n_heads = cfg.MODEL.TRANSFORMER.N_HEADS, + non_local = cfg.MODEL.TRANSFORMER.NON_LOCAL, + do_mel_guide = cfg.MODEL.DO_MEL_GUIDE, + noise_from_data = cfg.MODEL.BGNOISE_FROMDATA, + specsup=cfg.FINETUNE.SPECSUP, + power_synth = cfg.MODEL.POWER_SYNTH, + onedconfirst=cfg.MODEL.ONEDCONFIRST, + rnn_type = cfg.MODEL.RNN_TYPE, + rnn_layers = cfg.MODEL.RNN_LAYERS, + compute_db_loudness=cfg.MODEL.RNN_COMPUTE_DB_LOUDNESS, + bidirection = cfg.MODEL.BIDIRECTION + ) + model_s.cuda(local_rank) + model_s.eval() + model_s.requires_grad_(False) + # print(model) + if distributed: + model = nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + broadcast_buffers=False, + 
bucket_cap_mb=25, + find_unused_parameters=True) + model.device_ids = None + decoder = model.module.decoder + encoder = model.module.encoder + if hasattr(model.module,'ecog_encoder'): + ecog_encoder = model.module.ecog_encoder + if hasattr(model.module,'decoder_mel'): + decoder_mel = model.module.decoder_mel + else: + decoder = model.decoder + encoder = model.encoder + if hasattr(model,'ecog_encoder'): + ecog_encoder = model.ecog_encoder + if hasattr(model,'decoder_mel'): + decoder_mel = model.decoder_mel + + count_param_override.print = lambda a: logger.info(a) + + logger.info("Trainable parameters generator:") + count_parameters(decoder) + + logger.info("Trainable parameters discriminator:") + count_parameters(encoder) + + arguments = dict() + arguments["iteration"] = 0 + + if cfg.MODEL.ECOG: + if cfg.MODEL.SUPLOSS_ON_ECOGF: + optimizer = LREQAdam([ + {'params': ecog_encoder.parameters()} + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + else: + optimizer = LREQAdam([ + {'params': ecog_encoder.parameters()}, + {'params': decoder.parameters()}, + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + + else: + if cfg.MODEL.DO_MEL_GUIDE: + optimizer = LREQAdam([ + {'params': encoder.parameters()}, + {'params': decoder.parameters()}, + {'params': decoder_mel.parameters()}, + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + else: + optimizer = LREQAdam([ + {'params': encoder.parameters()}, + {'params': decoder.parameters()} + ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) + + scheduler = ComboMultiStepLR(optimizers= + {'optimizer': optimizer}, + milestones=cfg.TRAIN.LEARNING_DECAY_STEPS, + gamma=cfg.TRAIN.LEARNING_DECAY_RATE, + reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES) + model_dict = { + 'encoder': encoder, + 'generator': decoder, + } + if hasattr(model,'ecog_encoder'): + model_dict['ecog_encoder'] = ecog_encoder + if hasattr(model,'decoder_mel'): + model_dict['decoder_mel'] = decoder_mel + if local_rank == 0: + model_dict['encoder_s'] = model_s.encoder + model_dict['generator_s'] = model_s.decoder + if hasattr(model_s,'ecog_encoder'): + model_dict['ecog_encoder_s'] = model_s.ecog_encoder + if hasattr(model_s,'decoder_mel'): + model_dict['decoder_mel_s'] = model_s.decoder_mel + + tracker = LossTracker(cfg.OUTPUT_DIR) + tracker_test = LossTracker(cfg.OUTPUT_DIR,test=True) + + auxiliary = { + 'optimizer': optimizer, + 'scheduler': scheduler, + 'tracker': tracker, + 'tracker_test':tracker_test, + } + + checkpointer = Checkpointer(cfg, + model_dict, + auxiliary, + logger=logger, + save=local_rank == 0) + + extra_checkpoint_data = checkpointer.load(ignore_last_checkpoint=False,ignore_auxiliary=cfg.FINETUNE.FINETUNE,file_name='output/10231600/model_epoch56.pth') + logger.info("Starting from epoch: %d" % (scheduler.start_epoch())) + + arguments.update(extra_checkpoint_data) + + + with open('train_param.json','r') as rfile: + param = json.load(rfile) + # data_param, train_param, test_param = param['Data'], param['Train'], param['Test'] + dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS,param=param) + dataset_test = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS,train=False,param=param) + # noise_dist = 
dataset.noise_dist + noise_dist = torch.from_numpy(dataset.noise_dist).to('cuda').float() + if cfg.MODEL.BGNOISE_FROMDATA: + model_s.noise_dist_init(noise_dist) + model.noise_dist_init(noise_dist) + rnd = np.random.RandomState(3456) + # latents = rnd.randn(len(dataset_test.dataset), cfg.MODEL.LATENT_SPACE_SIZE) + # samplez = torch.tensor(latents).float().cuda() + + + if cfg.DATASET.SAMPLES_PATH: + path = cfg.DATASET.SAMPLES_PATH + src = [] + with torch.no_grad(): + for filename in list(os.listdir(path))[:32]: + img = np.asarray(Image.open(os.path.join(path, filename))) + if img.shape[2] == 4: + img = img[:, :, :3] + im = img.transpose((2, 0, 1)) + x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1. + if x.shape[0] == 4: + x = x[:3] + src.append(x) + sample = torch.stack(src) + else: + dataset_test.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, len(dataset_test.dataset)) + sample_dict_test = next(iter(dataset_test.iterator)) + # sample_dict_test = concate_batch(sample_dict_test) + sample_wave_test = sample_dict_test['wave_re_batch_all'].to('cuda').float() + if cfg.MODEL.WAVE_BASED: + sample_spec_test = sample_dict_test['wave_spec_re_batch_all'].to('cuda').float() + sample_spec_amp_test = sample_dict_test['wave_spec_re_amp_batch_all'].to('cuda').float() + sample_spec_denoise_test = sample_dict_test['wave_spec_re_denoise_batch_all'].to('cuda').float() + # sample_spec_test = wave2spec(sample_wave_test,n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB) + else: + sample_spec_test = sample_dict_test['spkr_re_batch_all'].to('cuda').float() + sample_spec_denoise_test = None#sample_dict_test['wave_spec_re_denoise_batch_all'].to('cuda').float() + sample_label_test = sample_dict_test['label_batch_all'] + if cfg.MODEL.ECOG: + ecog_test = [sample_dict_test['ecog_re_batch_all'][i].to('cuda').float() for i in range(len(sample_dict_test['ecog_re_batch_all']))] + mask_prior_test = [sample_dict_test['mask_all'][i].to('cuda').float() for i in range(len(sample_dict_test['mask_all']))] + mni_coordinate_test = sample_dict_test['mni_coordinate_all'].to('cuda').float() + else: + ecog_test = None + mask_prior_test = None + mni_coordinate_test = None + sample_spec_mel_test = sample_dict_test['spkr_re_batch_all'].to('cuda').float() if cfg.MODEL.DO_MEL_GUIDE else None + on_stage_test = sample_dict_test['on_stage_re_batch_all'].to('cuda').float() + on_stage_wider_test = sample_dict_test['on_stage_wider_re_batch_all'].to('cuda').float() + # sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank)) + # sample = (sample / 127.5 - 1.) 
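+    # NOTE: the spectrogram inputs below come precomputed from the dataloader
+    # ('wave_spec_re_batch_all'); the commented-out wave2spec calls show the
+    # on-the-fly alternative. A minimal, hypothetical sketch of that transform,
+    # assuming it is an STFT magnitude mapped to dB and rescaled into [0, 1]
+    # between noise_db and max_db (the real function is net_formant.wave2spec;
+    # names and defaults here are illustrative only, not the project's values):
+    def wave2spec_sketch(wave, n_fft=256, hop_length=128, noise_db=-50.0, max_db=22.5):
+        # wave: (batch, samples) float tensor at 16 kHz
+        window = torch.hann_window(n_fft, device=wave.device)
+        spec = torch.stft(wave, n_fft=n_fft, hop_length=hop_length,
+                          window=window, return_complex=True).abs()
+        db = 20.0 * torch.log10(spec + 1e-10)           # amplitude -> dB
+        db = torch.clamp(db, min=noise_db, max=max_db)  # clip to noise floor / ceiling
+        return (db - noise_db) / (max_db - noise_db)    # normalize to [0, 1]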
+ # import pdb; pdb.set_trace() + duomask=True + # model.eval() + # Lrec = model(sample_spec_test, x_denoise = sample_spec_denoise_test,x_mel = sample_spec_mel_test,ecog=ecog_test if cfg.MODEL.ECOG else None, mask_prior=mask_prior_test if cfg.MODEL.ECOG else None, on_stage = on_stage_test,on_stage_wider = on_stage_wider_test, ae = not cfg.MODEL.ECOG, tracker = tracker_test, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate_test,debug = False,x_amp=sample_spec_amp_test,hamonic_bias = False) + # save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=0,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=True) + n_iter = 0 + for epoch in range(cfg.TRAIN.TRAIN_EPOCHS): + model.train() + + # batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank) + model.train() + need_permute = False + epoch_start_time = time.time() + i = 0 + for sample_dict_train in tqdm(iter(dataset.iterator)): + n_iter +=1 + # import pdb; pdb.set_trace() + # sample_dict_train = concate_batch(sample_dict_train) + i += 1 + wave_orig = sample_dict_train['wave_re_batch_all'].to('cuda').float() + if cfg.MODEL.WAVE_BASED: + # x_orig = wave2spec(wave_orig,n_fft=cfg.MODEL.N_FFT,noise_db=cfg.MODEL.NOISE_DB,max_db=cfg.MODEL.MAX_DB) + x_orig = sample_dict_train['wave_spec_re_batch_all'].to('cuda').float() + x_orig_amp = sample_dict_train['wave_spec_re_amp_batch_all'].to('cuda').float() + x_orig_denoise = sample_dict_train['wave_spec_re_denoise_batch_all'].to('cuda').float() + else: + x_orig = sample_dict_train['spkr_re_batch_all'].to('cuda').float() + x_orig_denoise = None#sample_dict_train['wave_spec_re_denoise_batch_all'].to('cuda').float() + + on_stage = sample_dict_train['on_stage_re_batch_all'].to('cuda').float() + on_stage_wider = sample_dict_train['on_stage_wider_re_batch_all'].to('cuda').float() + words = sample_dict_train['word_batch_all'].to('cuda').long() + words = words.view(words.shape[0]*words.shape[1]) + labels = sample_dict_train['label_batch_all'] + if cfg.MODEL.ECOG: + ecog = [sample_dict_train['ecog_re_batch_all'][j].to('cuda').float() for j in range(len(sample_dict_train['ecog_re_batch_all']))] + mask_prior = [sample_dict_train['mask_all'][j].to('cuda').float() for j in range(len(sample_dict_train['mask_all']))] + mni_coordinate = sample_dict_train['mni_coordinate_all'].to('cuda').float() + else: + ecog = None + mask_prior = None + mni_coordinate = None + x = x_orig + x_mel = sample_dict_train['spkr_re_batch_all'].to('cuda').float() if cfg.MODEL.DO_MEL_GUIDE else None + # x.requires_grad = True + # apply_cycle = cfg.MODEL.CYCLE and True + # apply_w_classifier = cfg.MODEL.W_CLASSIFIER and True + # apply_gp = True + # apply_ppl = cfg.MODEL.APPLY_PPL and True + # apply_ppl_d = cfg.MODEL.APPLY_PPL_D and True + # apply_encoder_guide = (cfg.FINETUNE.ENCODER_GUIDE or cfg.MODEL.W_SUP) and True + # apply_sup = cfg.FINETUNE.SPECSUP + + if (cfg.MODEL.ECOG): + optimizer.zero_grad() + Lrec,tracker = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=ecog, mask_prior=mask_prior, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = False, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,duomask=duomask,mni=mni_coordinate,x_amp=x_orig_amp) + #print 
('tracker',tracker,tracker.tracks) + #for key in ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']: + # print ('tracker, ',key, tracker.tracks[key].mean(dim=0)) + #Lae_a: 0.0938077, Lae_a_l2: 0.0938695, Lae_db: 0.2939660, Lae_db_l2: 0.2939744, Lrec: 0.2407907, loudness: 0.5835954, f0_hz: 0.3419901, amplitudes: 0.2974630, amplitude_formants_hamon: 2.7913144, freq_formants_hamon_hz: 4.6389437, amplitude_formants_noise: 0.2281826, freq_formants_noise_hz: 1.5711578, bandwidth_formants_noise: 1.5711578, Ldiff: 0.3452123, Lexp: -0.0683261 + ecog_key_list = ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lrec','loudness','f0_hz','amplitudes','amplitude_formants_hamon','freq_formants_hamon_hz','amplitude_formants_noise','freq_formants_noise_hz','bandwidth_formants_noise','Ldiff','Lexp'] + metric_key_list = ['loudness_metric','f0_metric','amplitudes_metric','amplitude_formants_hamon_metric','freq_formants_hamon_hz_metric_2','freq_formants_hamon_hz_metric_6','amplitude_formants_noise_metric'] + if n_iter %10==0: + #print ('write tensorboard') + writer.add_scalars('data/loss_group', {key: tracker.tracks[key].mean(dim=0) for key in ecog_key_list}, n_iter) + writer.add_scalars('data/metric_group', {key: tracker.tracks[key].mean(dim=0) for key in metric_key_list}, n_iter) + for key in ecog_key_list: + writer.add_scalar('data/'+key, tracker.tracks[key].mean(dim=0), n_iter) + for key in metric_key_list: + writer.add_scalar('data/'+key, tracker.tracks[key].mean(dim=0), n_iter) + #pass #log in tensorboard later!! + #tracker Lae_a: 0.0536877, Lae_a_l2: 0.0538647, Lae_db: 0.1655398, Lae_db_l2: 0.1655714, Lloudness: 1.0384552, Lae_denoise: 0.0485827, Lamp: 0.0000148, Lae: 2.1787138, Lexp: -1.9956266, Lf0: 0.0000000, Ldiff: 0.0568467 + (Lrec).backward() + optimizer.step() + else: + optimizer.zero_grad() + Lrec,tracker = model(x, x_denoise = x_orig_denoise,x_mel = x_mel,ecog=None, mask_prior=None, on_stage = on_stage,on_stage_wider = on_stage_wider, ae = True, tracker = tracker, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate,debug = False,x_amp=x_orig_amp,hamonic_bias = False)#epoch<2) + print ('tracker',tracker,tracker.tracks) + #for key in ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff']: + #print ('tracker, ',key, tracker.tracks[key].mean(dim=0)) + ecog_key_list = ['Lae_a','Lae_a_l2','Lae_db','Lae_db_l2','Lloudness','Lae_denoise','Lamp','Lae','Lexp','Lf0','Ldiff'] + if n_iter %10==0: + #print ('write tensorboard') + writer.add_scalars('data/loss_group', {key: tracker.tracks[key].mean(dim=0) for key in ecog_key_list}, n_iter) + for key in ecog_key_list: + writer.add_scalar('data/'+key, tracker.tracks[key].mean(dim=0), n_iter) + (Lrec).backward() + optimizer.step() + + betta = 0.5 ** (cfg.TRAIN.BATCH_SIZE / (10 * 1000.0)) + model_s.lerp(model, betta,w_classifier = cfg.MODEL.W_CLASSIFIER) + + epoch_end_time = time.time() + per_epoch_ptime = epoch_end_time - epoch_start_time + #print ('test save sample') + #save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=sample_spec_amp_test) + 
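+        # NOTE: model_s is an exponential moving average (EMA) of model, used for the
+        # saved/eval copies. With betta = 0.5 ** (BATCH_SIZE / 10000), the averaged
+        # weights have a half-life of roughly 10000 / BATCH_SIZE iterations. A rough
+        # sketch of what Model.lerp (defined in model_formant.py) is assumed to do,
+        # ignoring the w_classifier flag:
+        #
+        #     with torch.no_grad():
+        #         for p_s, p in zip(model_s.parameters(), model.parameters()):
+        #             p_s.lerp_(p, 1.0 - betta)  # p_s <- betta * p_s + (1 - betta) * p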
#save_sample(x,ecog,mask_prior,mni_coordinate,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=x_orig_denoise,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=labels,mode='train',path=cfg.OUTPUT_DIR,tracker = tracker,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=x_orig_amp) + #print ('finish') + + if local_rank == 0: + print(2**(torch.tanh(model.encoder.formant_bandwitdh_slop))) + checkpointer.save("model_epoch%d" % epoch) + model.eval() + Lrec = model(sample_spec_test, x_denoise = sample_spec_denoise_test,x_mel = sample_spec_mel_test,ecog=ecog_test if cfg.MODEL.ECOG else None, mask_prior=mask_prior_test if cfg.MODEL.ECOG else None, on_stage = on_stage_test,on_stage_wider = on_stage_wider_test, ae = not cfg.MODEL.ECOG, tracker = tracker_test, encoder_guide=cfg.MODEL.W_SUP,pitch_aug=False,duomask=duomask,mni=mni_coordinate_test,debug = False,x_amp=sample_spec_amp_test,hamonic_bias = False) + if epoch%1==0: + #first mode is test + reconaudio = save_sample(sample_spec_test,ecog_test,mask_prior_test,mni_coordinate_test,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=sample_spec_denoise_test,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=sample_label_test,mode='test',path=cfg.OUTPUT_DIR,tracker = tracker_test,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=sample_spec_amp_test) + save_sample(x,ecog,mask_prior,mni_coordinate,encoder,decoder,ecog_encoder if cfg.MODEL.ECOG else None,x_denoise=x_orig_denoise,x_mel = sample_spec_mel_test,decoder_mel=decoder_mel if cfg.MODEL.DO_MEL_GUIDE else None,epoch=epoch,label=labels,mode='train',path=cfg.OUTPUT_DIR,tracker = tracker,linear=cfg.MODEL.WAVE_BASED,n_fft=cfg.MODEL.N_FFT,duomask=duomask,x_amp=x_orig_amp) + writer.add_audio('reconaudio', reconaudio, n_iter, sample_rate=16000) + writer.export_scalars_to_json(cfg.OUTPUT_DIR+"/all_scalars.json") + writer.close() + +if __name__ == "__main__": + gpu_count = torch.cuda.device_count() + cfg = get_cfg_defaults() + + #if args.modeldir !='': + # cfg.OUTPUT_DIR = args.modeldir + parser = argparse.ArgumentParser(description='ecog formant model') + parser.add_argument( + "-c", "--config-file", + default='configs/ecog_style2.yaml', + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument('--ONEDCONFIRST', type=int,default=1, + help='use one d conv before lstm') + parser.add_argument('--RNN_TYPE', type=str,default='LSTM', + help='LSTM or GRU') + parser.add_argument('--RNN_LAYERS', type=int,default=4, + help='lstm layers') + parser.add_argument('--RNN_COMPUTE_DB_LOUDNESS', type=int,default=1, + help='RNN_COMPUTE_DB_LOUDNESS') + parser.add_argument('--BIDIRECTION', type=int,default=1, + help='BIDIRECTION') + parser.add_argument('--MAPPING_FROM_ECOG', type=str,default='ECoGMappingBottlenecklstm_pure', + help='MAPPING_FROM_ECOG') + parser.add_argument('--OUTPUT_DIR', type=str,default='output/ecog_11021800_lstmpure', + help='OUTPUT_DIR') + args = parser.parse_args() + + cfg.MODEL.ONEDCONFIRST = True if args.ONEDCONFIRST==1 else False + cfg.MODEL.RNN_TYPE = args.RNN_TYPE + cfg.MODEL.RNN_LAYERS = args.RNN_LAYERS + cfg.MODEL.RNN_COMPUTE_DB_LOUDNESS = True if args.RNN_COMPUTE_DB_LOUDNESS==1 else False + cfg.MODEL.BIDIRECTION = True if args.BIDIRECTION==1 else 
False + cfg.OUTPUT_DIR = args.OUTPUT_DIR +'_{}_{}_bi_{}_1dconv_{}'.format(args.RNN_TYPE,args.RNN_LAYERS,str(cfg.MODEL.BIDIRECTION),str(cfg.MODEL.ONEDCONFIRST)) + cfg.MODEL.MAPPING_FROM_ECOG = args.MAPPING_FROM_ECOG + + run(train, cfg, description='StyleGAN', default_config='configs/ecog_style2.yaml', + world_size=gpu_count,args=args) diff --git a/train_param.json b/train_param.json new file mode 100644 index 00000000..6f7ac467 --- /dev/null +++ b/train_param.json @@ -0,0 +1,56 @@ +{ + "Prod":true, + "SpecBands":64, + "SelectRegion":["AUDITORY","BROCA","MOTO","SENSORY"], + "BlockRegion":[], + "UseGridOnly":true, + "ReshapeAsGrid":false, + "SeqLen":128, + "DOWN_TF_FS": 125, + "DOWN_ECOG_FS": 125, + "Subj":{ + "NY717":{ + "Crop": null, + "Task": ["VisRead","SenComp","PicN","AudN","AudRep"], + "TestNum":[10,10,10,10,10] + }, + "NY742":{ + "Crop": null, + "Task": ["VisRead","SenComp","PicN","AudN","AudRep"], + "TestNum":[10,10,10,10,10] + }, + "NY749":{ + "Crop": null, + "Task": ["VisRead","SenComp","PicN","AudN","AudRep"], + "TestNum":[10,10,10,10,10] + } + }, + "Train":{ + "lr": 0.001, + "gamma": 0.8, + "no_cuda": false, + "batch_size": 10, + "num_epochs": 1000, + "save_model": true, + "save_interval": 50, + "save_dir": "/scratch/rw1691/connectivity/ECoG/Connectivity/CKpts/", + "log_interval": 100, + "ahead_onset": 32, + "loss": "L2", + "lam_reg": 0.01 + }, + "Test":{ + "test_interval": 5, + "batch_size": 10, + "ahead_onset": 32 + }, + "Analyze":{ + "epoch": 899, + "batch_size":2, + "SeqLen": 400, + "ahead_onset": 200, + "save_path": "/scratch/rw1691/connectivity/ECoG/Connectivity/AnalyzeResult" + + } + +} diff --git a/transformer_models/._backbone.py b/transformer_models/._backbone.py new file mode 100644 index 00000000..58dd64eb Binary files /dev/null and b/transformer_models/._backbone.py differ diff --git a/transformer_models/._detr.py b/transformer_models/._detr.py new file mode 100644 index 00000000..c1286390 Binary files /dev/null and b/transformer_models/._detr.py differ diff --git a/transformer_models/._position_encoding.py b/transformer_models/._position_encoding.py new file mode 100644 index 00000000..e6112c50 Binary files /dev/null and b/transformer_models/._position_encoding.py differ diff --git a/transformer_models/._transformer.py b/transformer_models/._transformer.py new file mode 100644 index 00000000..c9952e4a Binary files /dev/null and b/transformer_models/._transformer.py differ diff --git a/transformer_models/__init__.py b/transformer_models/__init__.py new file mode 100644 index 00000000..a3f26531 --- /dev/null +++ b/transformer_models/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from .detr import build + + +def build_model(args): + return build(args) diff --git a/transformer_models/backbone.py b/transformer_models/backbone.py new file mode 100644 index 00000000..96680932 --- /dev/null +++ b/transformer_models/backbone.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. 
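+
+    Freezing the statistics keeps the backbone's normalization identical to its
+    pretrained behaviour, which matters when detection batches are too small for
+    stable batch statistics. The forward pass folds everything into a precomputed
+    scale and bias: y = (x - running_mean) / sqrt(running_var + eps) * weight + bias.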
+ + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/transformer_models/detr.py b/transformer_models/detr.py new file mode 100644 index 00000000..d58dcd60 --- /dev/null +++ 
b/transformer_models/detr.py @@ -0,0 +1,349 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn + +from util import box_ops +from util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized) + +from .backbone import build_backbone +from .matcher import build_matcher +from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm, + dice_loss, sigmoid_focal_loss) +from .transformer import build_transformer + + +class DETR(nn.Module): + """ This is the DETR module that performs object detection """ + def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) + self.backbone = backbone + self.aux_loss = aux_loss + + def forward(self, samples: NestedTensor): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + src, mask = features[-1].decompose() + assert mask is not None + hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] + + outputs_class = self.class_embed(hs) + outputs_coord = self.bbox_embed(hs).sigmoid() + out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
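+        # One {'pred_logits', 'pred_boxes'} dict per intermediate decoder layer; the
+        # final layer is excluded because its predictions are already the main output.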
+ return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer('empty_weight', empty_weight) + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
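+           Both the L1 and the GIoU terms are summed over matched pairs and divided by
+           num_boxes, the average number of target boxes per image across all processes.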
+ """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. 
+ The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = F.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # convert to [x0, y0, x1, y1] format + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + + return results + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build(args): + num_classes = 20 if args.dataset_file != 'coco' else 91 + if args.dataset_file == "coco_panoptic": + num_classes = 250 + device = torch.device(args.device) + + backbone = build_backbone(args) + + transformer = build_transformer(args) + + model = DETR( + backbone, + transformer, + num_classes=num_classes, + 
num_queries=args.num_queries, + aux_loss=args.aux_loss, + ) + if args.masks: + model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) + matcher = build_matcher(args) + weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} + weight_dict['loss_giou'] = args.giou_loss_coef + if args.masks: + weight_dict["loss_mask"] = args.mask_loss_coef + weight_dict["loss_dice"] = args.dice_loss_coef + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ['labels', 'boxes', 'cardinality'] + if args.masks: + losses += ["masks"] + criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, + eos_coef=args.eos_coef, losses=losses) + criterion.to(device) + postprocessors = {'bbox': PostProcess()} + if args.masks: + postprocessors['segm'] = PostProcessSegm() + if args.dataset_file == "coco_panoptic": + is_thing_map = {i: i <= 90 for i in range(201)} + postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) + + return model, criterion, postprocessors diff --git a/transformer_models/matcher.py b/transformer_models/matcher.py new file mode 100644 index 00000000..0c291473 --- /dev/null +++ b/transformer_models/matcher.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). 
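+
+    The pairwise matching cost is a weighted sum
+        C = cost_class * (-prob[target_class]) + cost_bbox * L1(pred_box, tgt_box) + cost_giou * (-GIoU(pred_box, tgt_box)),
+    and the optimal one-to-one assignment is solved per image with scipy.optimize.linear_sum_assignment.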
+ """ + + def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +def build_matcher(args): + return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) diff --git a/transformer_models/position_encoding.py b/transformer_models/position_encoding.py new file mode 100644 index 00000000..3ba13535 --- /dev/null +++ b/transformer_models/position_encoding.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. 
+""" +import math +import torch +from torch import nn + +from util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) #BxCxWxH + return pos + +class PositionEmbeddingLearnedECoG(nn.Module): + """ + Absolute pos embedding, learned. 
+ """ + def __init__(self, SeqLength, num_pos_feats=256): + super().__init__() + self.SeqLength = SeqLength + self.time_embed = nn.Embedding(SeqLength, num_pos_feats) + # self.elec_embed = nn.Embedding(50, num_pos_feats) + self.elec_embed = nn.Linear(3,num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.time_embed.weight) + # nn.init.uniform_(self.elec_embed.weight) + + def forward(self, x): + # x: MNI Bx3xE + e = x.shape[-1] + t = self.SeqLength + j = torch.arange(t, device=x.device) + x = x.permute(0,2,1) # BxEx3 + elec_emb = self.elec_embed(x) #BxExnum_pos_feats + time_emb = self.time_embed(j) #Txnum_pos_feats + pos = torch.cat([ + elec_emb.unsqueeze(1).repeat(1, t, 1, 1), + time_emb.unsqueeze(1).repeat(1, e, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1), + ], dim=-1).permute(0, 3, 1, 2)# BxCxTxE + return pos + +def build_position_encoding(SeqLength,hidden_dim,method='MNI'): + N_steps = hidden_dim // 2 + if method in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif method in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + elif method in ('MNI'): + position_embedding = PositionEmbeddingLearnedECoG(SeqLength,N_steps) + else: + raise ValueError(f"not supported {method}") + + return position_embedding diff --git a/transformer_models/segmentation.py b/transformer_models/segmentation.py new file mode 100644 index 00000000..edfc32ef --- /dev/null +++ b/transformer_models/segmentation.py @@ -0,0 +1,363 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +This file provides the definition of the convolutional heads used to predict masks, as well as the losses +""" +import io +from collections import defaultdict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from PIL import Image + +import util.box_ops as box_ops +from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list + +try: + from panopticapi.utils import id2rgb, rgb2id +except ImportError: + pass + + +class DETRsegm(nn.Module): + def __init__(self, detr, freeze_detr=False): + super().__init__() + self.detr = detr + + if freeze_detr: + for p in self.parameters(): + p.requires_grad_(False) + + hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead + self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) + self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) + + def forward(self, samples: NestedTensor): + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.detr.backbone(samples) + + bs = features[-1].tensors.shape[0] + + src, mask = features[-1].decompose() + assert mask is not None + src_proj = self.detr.input_proj(src) + hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) + + outputs_class = self.detr.class_embed(hs) + outputs_coord = self.detr.bbox_embed(hs).sigmoid() + out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} + if self.detr.aux_loss: + out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord) + + # FIXME h_boxes takes the last one computed, keep this in mind + bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) + + seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, 
features[0].tensors]) + outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) + + out["pred_masks"] = outputs_seg_masks + return out + + +def _expand(tensor, length: int): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + +class MaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. + Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + + inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] + self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, dim) + self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]): + x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + +class MHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + nn.init.zeros_(self.k_linear.bias) + nn.init.zeros_(self.q_linear.bias) + nn.init.xavier_uniform_(self.k_linear.weight) + nn.init.xavier_uniform_(self.q_linear.weight) + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask: Optional[Tensor] = None): + q = self.q_linear(q) + k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), 
self.k_linear.bias) + qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) + weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) + + if mask is not None: + weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) + weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) + weights = self.dropout(weights) + return weights + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. 
+ Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +class PostProcessSegm(nn.Module): + def __init__(self, threshold=0.5): + super().__init__() + self.threshold = threshold + + @torch.no_grad() + def forward(self, results, outputs, orig_target_sizes, max_target_sizes): + assert len(orig_target_sizes) == len(max_target_sizes) + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs["pred_masks"].squeeze(2) + outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) + outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = F.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + +class PostProcessPanoptic(nn.Module): + """This class converts the output of the model to the final panoptic result, in the format expected by the + coco panoptic API """ + + def __init__(self, is_thing_map, threshold=0.85): + """ + Parameters: + is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether + the class is a thing (True) or a stuff (False) class + threshold: confidence threshold: segments with confidence lower than this will be deleted + """ + super().__init__() + self.threshold = threshold + self.is_thing_map = is_thing_map + + def forward(self, outputs, processed_sizes, target_sizes=None): + """ This function computes the panoptic prediction from the model's predictions. + Parameters: + outputs: This is a dict coming directly from the model. See the model doc for the content. + processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the + model, ie the size after data augmentation but before batching. + target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size + of each prediction. 
If left to None, it will default to the processed_sizes + """ + if target_sizes is None: + target_sizes = processed_sizes + assert len(processed_sizes) == len(target_sizes) + out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] + assert len(out_logits) == len(raw_masks) == len(target_sizes) + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.cpu().tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + ): + # we filter empty queries and detection below threshold + scores, labels = cur_logits.softmax(-1).max(-1) + keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) + cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) + cur_scores = cur_scores[keep] + cur_classes = cur_classes[keep] + cur_masks = cur_masks[keep] + cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0) + cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + assert len(cur_boxes) == len(cur_classes) + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_classes): + if not self.is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) + + np_seg_img = ( + torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() + ) + m_id = torch.from_numpy(rgb2id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_classes.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_classes[i].item() + segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) + del cur_classes + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = {"png_string": 
out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds diff --git a/transformer_models/transformer.py b/transformer_models/transformer.py new file mode 100644 index 00000000..611e4877 --- /dev/null +++ b/transformer_models/transformer.py @@ -0,0 +1,302 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +class Transformer(nn.Module): + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False,encoder_only = False): + super().__init__() + self.encoder_only = encoder_only + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + if not encoder_only: + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC, mask = NxHxW + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + try: + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + except: + import pdb;pdb.set_trace() + if not self.encoder_only: + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + tgt = torch.zeros_like(query_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + return hs.permute(1, 2, 0), memory.permute(1, 2, 0).view(bs, c, h, w) + else: + return None,memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + 
memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = 
nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/transformer_models/transformer_nonlocal.py b/transformer_models/transformer_nonlocal.py new file mode 100644 index 00000000..d34b87e7 --- /dev/null +++ 
b/transformer_models/transformer_nonlocal.py @@ -0,0 +1,342 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +import lreq as ln + + +class NonlocalAttention(nn.Module): + def __init__(self, d_model, nhead, dropout, att_denorm=2): + super(NonlocalAttention, self).__init__() + # Channel multiplier + self.d_model = d_model + self.heads = nhead + self.dropout = dropout + self.att_denorm = att_denorm + self.theta = ln.Conv2d(d_model, d_model // self.att_denorm, 1,1,0, bias=True) #query + self.phi = ln.Conv2d(d_model, d_model // self.att_denorm, 1,1,0, bias=True) #key + self.g = ln.Conv2d(d_model, d_model, 1,1,0, bias=True) #value + self.drop = nn.Dropout(dropout) + + # # Learnable gain parameter + # self.gamma = P(torch.tensor(0.), requires_grad=True) + + def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): + # Apply convs + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + theta = self.theta(query) + phi = F.max_pool2d(self.phi(key), [2,1]) + g = F.max_pool2d(self.g(value), [2,1]) + # Perform reshapes + self.theta_ = theta.reshape(-1, self.d_model // self.att_denorm//self.heads, self.heads ,query.shape[2] * query.shape[3]) + self.phi_ = phi.reshape(-1, self.d_model // self.att_denorm//self.heads, self.heads, key.shape[2] * key.shape[3] // 2) + g = g.reshape(-1, self.d_model//self.heads, self.heads, value.shape[2] * value.shape[3] // 2) + # Matmul and softmax to get attention maps + self.beta = F.softmax(torch.einsum('bchi,bchj->bhij',self.theta_, self.phi_), -1) + self.beta = self.drop(self.beta) + # self.beta = F.softmax(torch.bmm(self.theta_, self.phi_), -1) + # Attention map times g path + o = torch.einsum('bchj,bhij->bchi',g, self.beta).reshape(-1, self.d_model, query.shape[2], query.shape[3]) + # o = self.o(torch.bmm(g, self.beta.transpose(1,2)).view(-1, self.inputs // 2, x.shape[2], x.shape[3])) + return o, self.beta + +class Transformer(nn.Module): + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False,encoder_only = False): + super().__init__() + self.encoder_only = encoder_only + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm([d_model,128,80]) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + if not encoder_only: + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm([d_model,128,1]) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC, mask = NxHxW, query_embed: LxC + bs, c, h, w 
= src.shape + # src = src.flatten(2).permute(2, 0, 1) + # pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + if not self.encoder_only: + query_embed = query_embed.unsqueeze(0).repeat(bs, 1, 1) + query_embed = query_embed.permute(0,2,1).unsqueeze(-1) #BxCxLx1 + tgt = torch.zeros_like(query_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + hs = hs.squeeze(-1) + return hs, memory + else: + return None,memory + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = NonlocalAttention(d_model, nhead, dropout=dropout) + + # Implementation of Feedforward model + self.linear1 = ln.Conv2d(d_model, dim_feedforward,1) + self.dropout = nn.Dropout(dropout) + self.linear2 = ln.Conv2d(dim_feedforward, d_model,1) + + self.norm1 = nn.LayerNorm([d_model,128,80]) + self.norm2 = nn.LayerNorm([d_model,128,80]) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = 
self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = NonlocalAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = NonlocalAttention(d_model, nhead, dropout=dropout) + + # Implementation of Feedforward model + self.linear1 = ln.Conv2d(d_model, dim_feedforward,1) + self.dropout = nn.Dropout(dropout) + self.linear2 = ln.Conv2d(dim_feedforward, d_model,1) + + self.norm1 = nn.LayerNorm([d_model,128,1]) + self.norm2 = nn.LayerNorm([d_model,128,1]) + self.norm3 = nn.LayerNorm([d_model,128,1]) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = 
tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/transformer_models/util/__init__.py b/transformer_models/util/__init__.py new file mode 100644 index 00000000..168f9979 --- /dev/null +++ b/transformer_models/util/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/transformer_models/util/box_ops.py b/transformer_models/util/box_ops.py new file mode 100644 index 00000000..9c088e5b --- /dev/null +++ b/transformer_models/util/box_ops.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. 
+""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/transformer_models/util/misc.py b/transformer_models/util/misc.py new file mode 100644 index 00000000..45d055d9 --- /dev/null +++ b/transformer_models/util/misc.py @@ -0,0 +1,416 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if float(torchvision.__version__[:3]) < 0.7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = 
list(zip(*batch)) + batch[0] = NestedTensor.from_tensor_list(batch[0]) + return tuple(batch) + + +class NestedTensor(object): + def __init__(self, tensors, mask): + self.tensors = tensors + self.mask = mask + + def to(self, *args, **kwargs): + cast_tensor = self.tensors.to(*args, **kwargs) + cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None + return type(self)(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + @classmethod + def from_tensor_list(cls, tensor_list): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list])) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = (len(tensor_list),) + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return cls(tensor, mask) + + def __repr__(self): + return repr(self.tensors) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return 
res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/transformer_models/util/plot_utils.py b/transformer_models/util/plot_utils.py new file mode 100644 index 00000000..4a03f43f --- /dev/null +++ b/transformer_models/util/plot_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +from pathlib import Path +import seaborn as sns +import matplotlib.pyplot as plt + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0): + dfs = [pd.read_json(Path(p) / 'log.txt', lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs
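The MNI-based positional encoding added earlier in this diff (PositionEmbeddingLearnedECoG together with build_position_encoding) maps each electrode's 3-D MNI coordinate through a learned linear layer and concatenates it with a learned time embedding, giving one positional feature per (time step, electrode) pair. A minimal shape-check sketch; the import path is an assumption, since only the class bodies appear in the diff:

```python
import torch
# Assumed import path; adjust to wherever build_position_encoding lives in this repo,
# and make sure transformer_models/ is importable from your working directory.
from transformer_models.position_encoding import build_position_encoding

B, E, T, hidden_dim = 2, 64, 128, 256          # batch, electrodes, time steps, model width
mni = torch.randn(B, 3, E)                     # per-electrode MNI coordinates, B x 3 x E
pos_enc = build_position_encoding(SeqLength=T, hidden_dim=hidden_dim, method='MNI')
pos = pos_enc(mni)                             # B x C x T x E, with C == hidden_dim
assert pos.shape == (B, hidden_dim, T, E)      # half of C comes from space, half from time
```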
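The mask losses in transformer_models/segmentation.py follow DETR: dice_loss flattens the predicted logits per mask, while sigmoid_focal_loss down-weights easy pixels; both are normalised by the number of boxes. A toy sketch, assuming transformer_models/ is on sys.path so the file's own util.* imports resolve, and with masks already flattened to (num_masks, num_pixels) since dice_loss only flattens its first argument:

```python
import torch
from transformer_models.segmentation import dice_loss, sigmoid_focal_loss  # path as added in this diff

num_masks, num_pixels = 4, 1024
logits  = torch.randn(num_masks, num_pixels)                     # raw, pre-sigmoid mask predictions
targets = torch.randint(0, 2, (num_masks, num_pixels)).float()   # binary ground-truth masks

d = dice_loss(logits, targets, num_boxes=num_masks)
f = sigmoid_focal_loss(logits, targets, num_boxes=num_masks, alpha=0.25, gamma=2)
print(d.item(), f.item())   # both reduce to scalars averaged over num_boxes
```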
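transformer_models/transformer.py is the stock DETR transformer plus an encoder_only switch that skips the query decoder and returns only the encoder memory, reshaped back to NxCxHxW. A minimal sketch of that mode (layer counts and sizes are arbitrary here):

```python
import torch
from transformer_models.transformer import Transformer  # as added in this diff

d_model, B, H, W = 256, 2, 16, 8                 # H*W plays the role of the flattened token grid
model = Transformer(d_model=d_model, nhead=8,
                    num_encoder_layers=2, num_decoder_layers=2,
                    encoder_only=True)

src = torch.randn(B, d_model, H, W)              # backbone / ECoG feature map, N x C x H x W
pos = torch.randn(B, d_model, H, W)              # positional encoding with the same shape as src
hs, memory = model(src, mask=None, query_embed=None, pos_embed=pos)
assert hs is None                                # no decoder output in encoder-only mode
assert memory.shape == (B, d_model, H, W)        # encoder memory restored to the input layout
```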
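transformer_models/util/box_ops.py carries over DETR's box utilities: conversion between (cx, cy, w, h) and (x0, y0, x1, y1), pairwise IoU, and generalized IoU. A short sketch of the conversion and the [N, M] GIoU matrix:

```python
import torch
from transformer_models.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

# Two boxes in (cx, cy, w, h), normalised coordinates.
boxes_cxcywh = torch.tensor([[0.5, 0.5, 0.4, 0.4],
                             [0.3, 0.3, 0.2, 0.2]])
boxes_xyxy = box_cxcywh_to_xyxy(boxes_cxcywh)        # -> (x0, y0, x1, y1), as generalized_box_iou expects
giou = generalized_box_iou(boxes_xyxy, boxes_xyxy)   # [N, N] pairwise matrix
assert torch.allclose(giou.diag(), torch.ones(2))    # every box has GIoU 1 with itself
```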
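NestedTensor.from_tensor_list in transformer_models/util/misc.py batches differently sized 3-D tensors by zero-padding them to the largest shape and recording the padding in a boolean mask (True = padded). A minimal sketch:

```python
import torch
from transformer_models.util.misc import NestedTensor

imgs = [torch.randn(3, 32, 40), torch.randn(3, 28, 36)]   # same channels, different spatial sizes
nt = NestedTensor.from_tensor_list(imgs)
tensors, mask = nt.decompose()
assert tensors.shape == (2, 3, 32, 40)   # padded to the largest H and W in the list
assert mask.shape == (2, 32, 40)         # True where an entry is padding, False where it is real data
```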