diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/GlobalRetrieval.py b/GlobalRetrieval.py new file mode 100644 index 0000000..68ada29 --- /dev/null +++ b/GlobalRetrieval.py @@ -0,0 +1,73 @@ +import torch +import faiss +import numpy as np +import time +import cv2 +import argparse +import os +from Matcher import Dinov2Matcher + +# #Define a function that normalizes embeddings and add them to the index +def add_vector_to_index(embedding, index): + #Convert to float32 numpy + vector = np.float32(embedding) + #Normalize vector: important to avoid wrong results when searching + faiss.normalize_L2(vector) + #Add to index + index.add(vector) + + +#Create Faiss index using FlatL2 type with 768 dimensions (DINOv2-base) as this + +def build_database(base_path, database_save_path): + + index = faiss.IndexFlatL2(768) + # Extract features + t0 = time.time() + image_paths = [] + for root, dirs, files in os.walk(base_path): + for file in files: + # 检查文件是否是图像类型 + if file.lower().endswith('.jpg'): + file_path = os.path.join(root, file) + image_paths.append(file_path) + + for image_path in image_paths: + image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + with torch.no_grad(): + image_tensor, _, _ = dm.prepare_image(image) + global_features = dm.extract_global_features(image_tensor) + add_vector_to_index(global_features, index) + + # Store the index locally + print('Extraction done in :', time.time()-t0) + faiss.write_index(index, database_save_path) + + +######### Retrieve the image to search +def retrieval_from_database(image_path, database_path): + #Extract the features + with torch.no_grad(): + image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + image_tensor, _, _ = dm.prepare_image(image) + global_features = dm.extract_global_features(image_tensor) + + vector = np.float32(global_features) + faiss.normalize_L2(vector) + + #Read the index file and perform search of top-3 images + index = faiss.read_index(database_path) + d,i = index.search(vector, 5) + print('distances:', d, 'indexes:', i) + +if __name__=='__main__': + dm = Dinov2Matcher(half_precision=False) + + parser = argparse.ArgumentParser(description='Global_trerieval') + parser.add_argument('--database_path', type=str, required=True, help='Path to save the database') + parser.add_argument('--data_path', type=str, required=True, help='Path of non-polyp reference images') + parser.add_argument('--image_path', type=str, required=True, help='Path of Query image') + args = parser.parse_args() + + build_database(args.data_path, args.database_path) + retrieval_from_database(args.image_path, args.database_path) \ No newline at end of file diff --git a/LocalMatching.py b/LocalMatching.py new file mode 100644 index 0000000..c451529 --- /dev/null +++ b/LocalMatching.py @@ -0,0 +1,121 @@ +from PIL import Image +import numpy as np +import matplotlib.pyplot as plt +import cv2 +from sklearn.neighbors import NearestNeighbors +import argparse +import matplotlib.pyplot as plt +from matplotlib.patches import ConnectionPatch +from sklearn.cluster import DBSCAN +from Matcher import Dinov2Matcher + +patch_size = 14 + + +def plot_matching_figure(image1, image2, xyA_list, xyB_list, save_path): + fig = plt.figure(figsize=(10,10)) + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + + ax1.imshow(image1) + ax2.imshow(image2) + + for xyA, xyB in zip(xyA_list, xyB_list): + con = ConnectionPatch(xyA=xyB, xyB=xyA, coordsA="data", coordsB="data", + axesA=ax2, axesB=ax1, color="red") + ax2.add_artist(con) + + fig.tight_layout() + fig.show() + fig.savefig(save_path) + + + +def MaskProposer(origin_image, origin_mask, target_image, target_mask, matching_figure_save_path=None): +# Init Dinov2Matcher + dm = Dinov2Matcher(half_precision=False) + # Extract image1 features + image1 = cv2.cvtColor(cv2.imread(origin_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + mask1 = cv2.imread(origin_mask, cv2.IMREAD_COLOR)[:,:,0] > 127 + image_tensor1, grid_size1, resize_scale1 = dm.prepare_image(image1) + features1 = dm.extract_local_features(image_tensor1) + print(features1.shape) + # Extract image2 features + image2 = cv2.cvtColor(cv2.imread(target_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + image_tensor2, grid_size2, resize_scale2 = dm.prepare_image(image2) + features2 = dm.extract_local_features(image_tensor2) + + # Build knn using features from image1, and query all features from image2 + knn = NearestNeighbors(n_neighbors=1) + knn.fit(features1) + distances, match2to1 = knn.kneighbors(features2) + match2to1 = np.array(match2to1) + + xyA_list = [] + xyB_list = [] + distances_list = [] + + for idx2, (dist, idx1) in enumerate(zip(distances, match2to1)): + row, col = dm.idx_to_source_position(idx1, grid_size1, resize_scale1) + xyA = (col, row) + if not mask1[int(row), int(col)]: continue # skip if feature is not on the object + row, col = dm.idx_to_source_position(idx2, grid_size2, resize_scale2) + xyB = (col, row) + xyB_list.append(xyB) + xyA_list.append(xyA) + distances_list.append(dist[0]) + + #Filter by distance + if len(xyA_list) > 30: + zip_list = list(zip(distances_list, xyA_list, xyB_list)) + zip_list.sort(key=lambda x: x[0]) + distances_list, xyA_list, xyB_list = zip(*zip_list) + xyA_list = xyA_list[:30] + xyB_list = xyB_list[:30] + + + if matching_figure_save_path is not None: + plot_matching_figure(image1, image2, xyA_list, xyB_list, matching_figure_save_path) + + # DBSCAN clustering + X = np.array(xyB_list) + clustering = DBSCAN(eps=2*patch_size+1 , min_samples=1).fit(X) + labels = clustering.labels_ + + # find the cluster with the most number of points + unique_labels, counts = np.unique(labels, return_counts=True) + max_label = unique_labels[np.argmax(counts)] + new_list = [xyB for i, xyB in enumerate(xyB_list) if labels[i] == max_label] + + #find the min-col and max-col of the cluster + min_col = np.min([xy[0] for xy in new_list]) - patch_size//2 + max_col = np.max([xy[0] for xy in new_list]) + patch_size//2 + #find the min-row and max-row of the cluster + min_row = np.min([xy[1] for xy in new_list]) - patch_size//2 + max_row = np.max([xy[1] for xy in new_list]) + patch_size//2 + + mask = np.zeros((image2.shape[0], image2.shape[1])) + mask[int(min_row):int(max_row), int(min_col):int(max_col)] = 255 + mask = mask.astype(np.uint8) + mask = Image.fromarray(mask).convert('L') + mask.save(target_mask) + return mask + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='LocalMatching') + parser.add_argument('--ref_image', type=str, required=True, help='Path of Reference image') + parser.add_argument('--ref_mask', type=str, required=True, help='Path of Reference mask') + parser.add_argument('--query_image', type=str, required=True, help='Path of Query image') + parser.add_argument('--mask_proposal', type=str, required=True, help='Save Path of Mask proposal') + parser.add_argument('--save_fig', type=str, default=None, help='Save the Matching Figure') + + args = parser.parse_args() + + mask = MaskProposer(origin_image=args.ref_image, + origin_mask=args.ref_mask, + target_image=args.query_image, + target_mask=args.mask_proposal, + matching_figure_save_path=args.save_fig + ) \ No newline at end of file diff --git a/Matcher.py b/Matcher.py new file mode 100644 index 0000000..1112154 --- /dev/null +++ b/Matcher.py @@ -0,0 +1,99 @@ +import torch +from transformers import AutoImageProcessor, AutoModel +from PIL import Image +import numpy as np +import torchvision.transforms as transforms +import torch.nn as nn + + +class GeM(nn.Module): + def __init__(self, p=3, eps=1e-6): + super(GeM, self).__init__() + self.p = nn.Parameter(torch.ones(1) * p) + self.eps = eps + + def forward(self, x): + # 假设x的形状是 (batch_size, sequence_length, feature_dim) + # 对每个样本的所有序列特征进行GeM pooling + x = x.clamp(min=self.eps).pow(self.p) + x = torch.mean(x, dim=1) # 沿着sequence_length维度取平均 + x = x.pow(1. / self.p) + return x + + def __repr__(self): + return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' + + +class Dinov2Matcher: + + def __init__(self, repo_name="facebookresearch/dinov2", model_name="dinov2_vitb14", smaller_edge_size=512, half_precision=False, device="cuda"): + self.repo_name = repo_name + self.model_name = model_name + # self.smaller_edge_size = smaller_edge_size + self.half_precision = half_precision + self.device = device + + if self.half_precision: + self.model = torch.hub.load(repo_or_dir=repo_name, model=model_name).half().to(self.device) + else: + self.model = torch.hub.load(repo_or_dir=repo_name, model=model_name).to(self.device) + + # self.model = AutoModel.from_pretrained('/data/lsy/workspace/hf_ckp/models--facebook--dinov2-base').to(device) + + self.model.eval() + + self.transform = transforms.Compose([ + # transforms.Resize(size=smaller_edge_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # imagenet defaults + ]) + self.gem_loss = GeM() + + # https://github.com/facebookresearch/dinov2/blob/255861375864acdd830f99fdae3d9db65623dafe/notebooks/features.ipynb + def prepare_image(self, rgb_image_numpy): + image = Image.fromarray(rgb_image_numpy) + image_tensor = self.transform(image) + resize_scale = image.width / image_tensor.shape[2] + + # Crop image to dimensions that are a multiple of the patch size + height, width = image_tensor.shape[1:] # C x H x W + cropped_width, cropped_height = width - width % self.model.patch_size, height - height % self.model.patch_size # crop a bit from right and bottom parts + image_tensor = image_tensor[:, :cropped_height, :cropped_width] + + grid_size = (cropped_height // self.model.patch_size, cropped_width // self.model.patch_size) + return image_tensor, grid_size, resize_scale + + def prepare_mask(self, mask_image_numpy, grid_size, resize_scale): + cropped_mask_image_numpy = mask_image_numpy[:int(grid_size[0]*self.model.patch_size*resize_scale), :int(grid_size[1]*self.model.patch_size*resize_scale)] + image = Image.fromarray(cropped_mask_image_numpy) + resized_mask = image.resize((grid_size[1], grid_size[0]), resample=Image.Resampling.NEAREST) + resized_mask = np.asarray(resized_mask).flatten() + return resized_mask + + def extract_global_features(self, image_tensor): + with torch.inference_mode(): + if self.half_precision: + image_batch = image_tensor.unsqueeze(0).half().to(self.device) + else: + image_batch = image_tensor.unsqueeze(0).to(self.device) + + tokens = self.model.get_intermediate_layers(image_batch)[0].mean(dim=1).detach().cpu() + + return tokens.numpy() + + def extract_local_features(self, image_tensor): + with torch.inference_mode(): + if self.half_precision: + image_batch = image_tensor.unsqueeze(0).half().to(self.device) + else: + image_batch = image_tensor.unsqueeze(0).to(self.device) + + tokens = self.model.get_intermediate_layers(image_batch)[0].squeeze() + return tokens.cpu().numpy() + + + def idx_to_source_position(self, idx, grid_size, resize_scale): + row = (idx // grid_size[1])*self.model.patch_size*resize_scale + self.model.patch_size / 2 + col = (idx % grid_size[1])*self.model.patch_size*resize_scale + self.model.patch_size / 2 + return row, col + diff --git a/README.md b/README.md new file mode 100644 index 0000000..7e5ebc2 --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ +# Polyp-Gen + +# 🛠Setup + +```bash +git clone https://github.com/Saint-lsy/Polyp-Gen.git +cd Polyp-Gen +conda create -n PolypGen python=3.10 +conda activate PolypGen +pip install -r requirements.txt +``` + +## Data Preparation +This model was trained by [LDPolypVideo](https://github.com/dashishi/LDPolypVideo-Benchmark) dataset. + +We filtered out some low-quality images with blurry, reflective, and ghosting effects, and finally select 55,883 samples including 29,640 polyp frames and 26,243 non-polyp frames. +## Checkpoint +You can download the chekpoints of our Polyp_Gen on [HuggingFace](https://huggingface.co/Saint-lsy/Polyp-Gen-sd2-inpainting/tree/main) + +## Sampling with Specified Mask +``` +python sample_one_image.py +``` + +## Sampling with Mask Proposer +The first step is building database and Global Retrieval. +```bash +python GlobalRetrieval.py --data_path /path/of/non-polyp/images --database_path /path/to/build/database --image_path /path/of/query/image/ +``` +The second step is Local Matching for query image. +```bash +python LocalMatching.py --ref_image /path/ref/image --ref_mask /path/ref/mask --query_image /path/query/image --mask_proposal /path/to/save/mask +``` +One Demo of LocalMatching +```bash +python LocalMatching.py --ref_image demos/img_1513_neg.jpg --ref_mask demos/mask_1513.jpg --query_image demos/img_1592_neg.jpg --mask_proposal gen_mask.jpg +``` + +The third step is using the generated Mask to sample. diff --git a/checkpoint/download_checkpoint.txt b/checkpoint/download_checkpoint.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/data.txt b/data/data.txt new file mode 100644 index 0000000..c5ec831 --- /dev/null +++ b/data/data.txt @@ -0,0 +1 @@ +# Download Modified LDPolypVideo \ No newline at end of file diff --git a/demos/demo_img.jpg b/demos/demo_img.jpg new file mode 100644 index 0000000..13f819a Binary files /dev/null and b/demos/demo_img.jpg differ diff --git a/demos/demo_mask.jpg b/demos/demo_mask.jpg new file mode 100644 index 0000000..1d8a939 Binary files /dev/null and b/demos/demo_mask.jpg differ diff --git a/demos/img_1513.jpg b/demos/img_1513.jpg new file mode 100644 index 0000000..0c9f036 Binary files /dev/null and b/demos/img_1513.jpg differ diff --git a/demos/img_1513_neg.jpg b/demos/img_1513_neg.jpg new file mode 100644 index 0000000..36b43cd Binary files /dev/null and b/demos/img_1513_neg.jpg differ diff --git a/demos/img_1592.jpg b/demos/img_1592.jpg new file mode 100644 index 0000000..8ba1e89 Binary files /dev/null and b/demos/img_1592.jpg differ diff --git a/demos/img_1592_neg.jpg b/demos/img_1592_neg.jpg new file mode 100644 index 0000000..8e9a365 Binary files /dev/null and b/demos/img_1592_neg.jpg differ diff --git a/demos/mask_1513.jpg b/demos/mask_1513.jpg new file mode 100644 index 0000000..c4851da Binary files /dev/null and b/demos/mask_1513.jpg differ diff --git a/metrics/calculate_FID.py b/metrics/calculate_FID.py new file mode 100644 index 0000000..a5cc80b --- /dev/null +++ b/metrics/calculate_FID.py @@ -0,0 +1,253 @@ +import os +import pathlib +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +import numpy as np +import torch +import torchvision.transforms as TF +from PIL import Image +from scipy import linalg +from torch.nn.functional import adaptive_avg_pool2d +try: + from tqdm import tqdm +except ImportError: + # If tqdm is not available, provide a mock version of it + def tqdm(x): + return x + +from pytorch_fid.inception import InceptionV3 +IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', + 'tif', 'tiff', 'webp'} + +class ImagePathDataset(torch.utils.data.Dataset): + def __init__(self, files, transforms=None): + self.files = files + self.transforms = transforms + + def __len__(self): + return len(self.files) + + def __getitem__(self, i): + path = self.files[i] + img = Image.open(path).convert('RGB') + if self.transforms is not None: + img = self.transforms(img) + return img + + +def get_activations(files, model, batch_size=50, dims=2048, device='cpu', + num_workers=1): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : Batch size of images for the model to process at once. + Make sure that the number of samples is a multiple of + the batch size, otherwise some samples are ignored. This + behavior is retained to match the original FID score + implementation. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- A numpy array of dimension (num images, dims) that contains the + activations of the given tensor when feeding inception with the + query tensor. + """ + model.eval() + + if batch_size > len(files): + print(('Warning: batch size is bigger than the data size. ' + 'Setting batch size to data size')) + batch_size = len(files) + + transform = TF.Compose([ + TF.Resize(512), + TF.CenterCrop(512), + TF.ToTensor(), + ]) + + dataset = ImagePathDataset(files, transform) + dataloader = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers) + + pred_arr = np.empty((len(files), dims)) + + start_idx = 0 + + for batch in tqdm(dataloader): + batch = batch.to(device) + + with torch.no_grad(): + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + + start_idx = start_idx + pred.shape[0] + + return pred_arr + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def calculate_activation_statistics(files, model, batch_size=50, dims=2048, + device='cpu', num_workers=1): + """Calculation of the statistics used by the FID. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : The images numpy array is split into batches with + batch size batch_size. A reasonable batch size + depends on the hardware. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. + """ + act = get_activations(files, model, batch_size, dims, device, num_workers) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def compute_statistics_of_path(path, model, batch_size, dims, device, + num_workers=1): + if path.endswith('.npz'): + with np.load(path) as f: + m, s = f['mu'][:], f['sigma'][:] + else: + path = pathlib.Path(path) + files = sorted([file for ext in IMAGE_EXTENSIONS + for file in path.glob('*.{}'.format(ext))]) + m, s = calculate_activation_statistics(files, model, batch_size, + dims, device, num_workers) + + return m, s + + +def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): + """Calculates the FID of two paths""" + for p in paths: + if not os.path.exists(p): + raise RuntimeError('Invalid path: %s' % p) + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + model = InceptionV3([block_idx]).to(device) + + m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, + dims, device, num_workers) + m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, + dims, device, num_workers) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + return fid_value + + +def save_fid_stats(paths, batch_size, device, dims, num_workers=1): + """Calculates the FID of two paths""" + if not os.path.exists(paths[0]): + raise RuntimeError('Invalid path: %s' % paths[0]) + + if os.path.exists(paths[1]): + raise RuntimeError('Existing output file: %s' % paths[1]) + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + model = InceptionV3([block_idx]).to(device) + + print(f"Saving statistics for {paths[0]}") + + m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, + dims, device, num_workers) + + np.savez_compressed(paths[1], mu=m1, sigma=s1) + +def main(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('--target_path', type=str, help='Path to the generated images') + parser.add_argument('--real_path', type=str, help='Path to the real images') + args = parser.parse_args() + + fid_value = calculate_fid_given_paths([args.real_path, args.target_path], + batch_size=50, + device='cuda', + dims=2048, + num_workers=1) + + print("FID:", fid_value) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/metrics/calculate_IS.py b/metrics/calculate_IS.py new file mode 100644 index 0000000..54c901b --- /dev/null +++ b/metrics/calculate_IS.py @@ -0,0 +1,91 @@ +from skimage.metrics import peak_signal_noise_ratio +from skimage.metrics import structural_similarity +import numpy as np +from PIL import Image +import os +import lpips +import torch +import torchvision.transforms as transforms +import torch.nn as nn +import torchvision.models as models +from scipy.stats import entropy +import argparse + + +# target_path = "/data/lsy/workspace/DiffAM/sample_imgs/ckp10k/" + +def main(target_path): + + + tgt_imgs_path = [] + for root, dirs, files in os.walk(target_path): + for file in files: + # 检查文件是否是图像类型 + if file.lower().endswith('.jpg'): + file_path = os.path.join(root, file) + tgt_imgs_path.append(file_path) + + # 使用预训练的InceptionV3模型 + inception_model = models.inception_v3(pretrained=True, transform_input=False) + inception_model.eval() + inception_model.cuda() # 如果有GPU的话 + + # 图像预处理 + preprocess = transforms.Compose([ + transforms.Resize(299), + transforms.CenterCrop(299), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + tgt_imgs = [preprocess(Image.open(img)) for img in tgt_imgs_path] + tgt_imgs = torch.stack(tgt_imgs) + + # 计算Inception Score + + def inception_score(imgs, inception_model, batch_size=32, splits=1): + N = len(imgs) + dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor + dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) + + # 获取预测值 + preds = np.zeros((N, 1000)) + for i, batch in enumerate(dataloader, 0): + batch = batch.type(dtype) + batch_size_i = batch.size()[0] + + with torch.no_grad(): + pred = inception_model(batch)[0] + + preds[i*batch_size:i*batch_size + batch_size_i] = pred.cpu().data.numpy() + + # 计算p(y|x) + preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True) + + # 计算KL散度 + split_scores = [] + for k in range(splits): + part = preds[k * (N // splits): (k+1) * (N // splits), :] + py = np.mean(part, axis=0) + scores = [] + for i in range(part.shape[0]): + pyx = part[i, :] + scores.append(entropy(pyx, py)) + split_scores.append(np.exp(np.mean(scores))) + + return np.mean(split_scores), np.std(split_scores) + + + # 计算Inception Score + inception_score_mean, inception_score_std = inception_score(tgt_imgs, inception_model) + print("Inception Score: ", inception_score_mean, inception_score_std) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--target_path', type=str, required=True, help='Path to the target directory') + args = parser.parse_args() + main(args.target_path) + + +#python IS_LPIPS.py --target_path /data/lsy/workspace/DiffAM/sample_imgs/random_mask/ckp30k \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..359f778 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +clip==0.2.0 +datasets==2.19.2 +diffusers==0.29.0 +einops==0.7.0 +imageio==2.34.1 +matplotlib==3.9.0 +opencv-python==4.10.0.84 +packaging==24.0 +pandas==2.2.2 +protobuf==3.20.1 +pytorch-fid==0.3.0 +safetensors==0.4.3 +scikit-image==0.23.2 +scikit-learn==1.5.0 +scipy==1.13.0 +seaborn==0.13.2 +timm==0.6.13 +tokenizers==0.19.1 +tomlkit==0.12.0 +toolz==0.12.1 +torch==2.1.2 +torchattacks==3.5.1 +torchaudio==2.1.1 +torchmetrics==1.3.2 +torchvision==0.16.1 +tqdm==4.66.2 +transformers==4.40.1 +wandb==0.17.2 \ No newline at end of file diff --git a/sample_one_img.py b/sample_one_img.py new file mode 100644 index 0000000..9b6749d --- /dev/null +++ b/sample_one_img.py @@ -0,0 +1,27 @@ +from diffusers import StableDiffusionInpaintPipeline +import torch +from PIL import Image + + +model_path = "/ckp/path/" +image = Image.open("demos/demo_img.jpg") +mask_image = Image.open("demos/demo_mask.jpg") + +prompt = 'Polyp' + +pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_path, + revision="fp16", + torch_dtype=torch.float16, + safety_checker=None, +) + +pipe = pipe.to("cuda") + +gen_image = pipe(prompt=prompt, image=image, mask_image=mask_image, + width=image.size[0], height=image.size[1], num_inference_steps=50, + ).images[0] + +gen_image.save("sample.jpg") + + diff --git a/samples.py b/samples.py new file mode 100644 index 0000000..9126e8b --- /dev/null +++ b/samples.py @@ -0,0 +1,78 @@ +from diffusers import StableDiffusionInpaintPipeline, UNet2DConditionModel +import torch +import pandas as pd +from PIL import Image +import numpy as np +import os +import argparse + + +def main(args): + save_path = args.save_path + checkpoint_path = args.checkpoint_path + test_csv = args.test_file + data_path = args.dataset_path + + pipe = StableDiffusionInpaintPipeline.from_pretrained( + checkpoint_path, + revision="fp16", + torch_dtype=torch.float16, + safety_checker=None, + ) + + if os.path.exists(save_path) == False: + os.makedirs(save_path) + + if not os.path.exists(save_path + "/0"): + os.makedirs(save_path + "/0") + if not os.path.exists(save_path + "/1"): + os.makedirs(save_path + "/1") + + pipe = pipe.to('cuda') + + for index, row in test_csv.iterrows(): + image_path = data_path + "/" + row['image'] + mask_path = data_path + "/" + row['seg'] + image = Image.open(image_path) + mask_image = Image.open(mask_path) + + #Enlarge the range of Mask + mask = np.array(mask_image) + rows, cols = np.where(mask == 255) + min_row, max_row = np.min(rows), np.max(rows) + min_col, max_col = np.min(cols), np.max(cols) + #Enlarge the range of Mask + min_row = max(0, min_row - 15) + max_row = min(mask.shape[0], max_row + 15) + min_col = max(0, min_col - 15) + max_col = min(mask.shape[1], max_col + 15) + mask = np.zeros_like(mask) + mask[min_row:max_row, min_col:max_col] = 255 + mask_image = Image.fromarray(mask, "L") + + if row['label'] == 0: + prompt = "Normal" + else: + prompt = "Polyp" + + gen_image = pipe(prompt=prompt, image=image, mask_image=mask_image, + width=image.size[0], height=image.size[1], num_inference_steps=50, + ).images[0] + + if index < 1000: + gen_image.save(save_path + "/0/img_" + str(index) + ".jpg") + else: + gen_image.save(save_path + "/1/img_" + str(index) + ".jpg") + + + +if __name__ == "__main__": + #add Args + parser = argparse.ArgumentParser(description='Sample inpainting') + parser.add_argument('--model_path', type=str, required=True, help='Path to save the checkpoints') + parser.add_argument('--data_path', type=str, required=True, help='Path of Dataset') + parser.add_argument('--test_file', type=str, required=True, help='Test *.csv file') + parser.add_argument('--save_path', type=str, required=True, help='Path to save the inpainted images') + parser.add_argument('--checkpoint_path', type=str, required=True, help='Path to the UNet checkpoint') + args = parser.parse_args() + main(args) diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100644 index 0000000..81f18ad --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1,24 @@ +export CUDA_VISIBLE_DEVICES=2 +export MODEL_NAME="stabilityai--stable-diffusion-2-inpainting" +export DATASET_NAME="\path\to\dataset" +export FILE_NAME="\path\to\train.csv" +accelerate launch train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --output_dir "out/dir" \ + --train_data_dir=$DATASET_NAME \ + --train_file=$FILE_NAME \ + --resolution=512 \ + --train_batch_size=2 \ + --max_train_steps=300000 \ + --gradient_accumulation_steps=4 \ + --learning_rate=1e-5 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=500 \ + --checkpointing_steps=5000 \ + --validation_steps=1000 \ + --validation_image "eval/img" \ + --validation_mask "eval/mask" \ + --validation_prompt "Polyp" \ + --num_validation_images=2 \ + --tracker_project_name="Polyp-Gen" \ + --report_to="wandb" \ \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..6a4fc13 --- /dev/null +++ b/train.py @@ -0,0 +1,1391 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +import random +import shutil +import contextlib +import gc +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset, concatenate_datasets +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionInpaintPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr +from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from PIL import Image, ImageDraw + + + +if is_wandb_available(): + import wandb + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.29.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/naruto-blip-captions": ("image", "text"), +} + + +def save_model_card( + args, + repo_id: str, + images: list = None, + repo_folder: str = None, +): + img_str = "" + if len(images) > 0: + image_grid = make_image_grid(images, 1, len(args.validation_prompts)) + image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) + img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" + + model_description = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +## Pipeline usage + +You can use the pipeline like so: + +```python +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) +prompt = "{args.validation_prompts[0]}" +image = pipeline(prompt).images[0] +image.save("my_image.png") +``` + +## Training info + +These are the key hyperparameters used during training: + +* Epochs: {args.num_train_epochs} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Gradient accumulation steps: {args.gradient_accumulation_steps} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + wandb_info = "" + if is_wandb_available(): + wandb_run_url = None + if wandb.run is not None: + wandb_run_url = wandb.run.url + + if wandb_run_url is not None: + wandb_info = f""" +More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). +""" + + model_description += wandb_info + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=args.pretrained_model_name_or_path, + model_description=model_description, + inference=True, + ) + + tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, step, is_final_validation=False): + logger.info("Running validation... ") + + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=accelerator.unwrap_model(vae), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + safety_checker=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt) and len(args.validation_image) == len(args.validation_mask): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + validation_masks = args.validation_mask + else: + raise ValueError( + "number of `args.validation_image`, `args.validation_mask`, and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") + + for validation_prompt, validation_image, validation_mask in zip(validation_prompts, validation_images, validation_masks): + validation_image = Image.open(validation_image).convert("RGB") + validation_mask = Image.open(validation_mask).convert("L") + + #Enlarge the range of Mask + mask = np.array(validation_mask) + rows, cols = np.where(mask == 255) + min_row, max_row = np.min(rows), np.max(rows) + min_col, max_col = np.min(cols), np.max(cols) + #Enlarge the range of Mask + min_row = max(0, min_row - 15) + max_row = min(mask.shape[0], max_row + 15) + min_col = max(0, min_col - 15) + max_col = min(mask.shape[1], max_col + 15) + mask = np.zeros_like(mask) + mask[min_row:max_row, min_col:max_col] = 255 + validation_mask = Image.fromarray(mask, "L") + + + validation_masked_image = Image.composite(Image.new('RGB', (validation_image.size[0], validation_image.size[1]), (0, 0, 0)), validation_image, validation_mask.convert("L")) + + images = [] + + for _ in range(args.num_validation_images): + with inference_ctx: + image = pipeline( + validation_prompt, + image=validation_image, + mask_image=validation_mask, + width=560, + height=480, + num_inference_steps=50, + generator=generator + ).images[0] + + images.append(image) + + image_logs.append( + {"validation_image": validation_image, + "images": images, + "validation_prompt": validation_prompt, + "validation_mask": validation_masked_image + } + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_mask = log["validation_mask"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + formatted_images.append(np.asarray(validation_mask)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_mask = log["validation_mask"] + + formatted_images.append(wandb.Image(validation_image, caption="origin image")) + formatted_images.append(wandb.Image(validation_mask, caption="mask image")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default='/data/lsy/workspace/hf_ckp/models--runwayml--stable-diffusion-v1-5', + # required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--train_file", + type=str, + default=None, + help=( + "The path of the Train.csv" + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--use_all_splits", + type=bool, + default=True, + help="Whether to use all splits of the dataset. If set, the dataset is concatenated into a single split." + " Otherwise, only the first split is used.", + ) + + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--dream_training", + action="store_true", + help=( + "Use the DREAM training method, which makes training more efficient and accurate at the ", + "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210", + ), + ) + parser.add_argument( + "--dream_detail_preservation", + type=float, + default=1.0, + help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_prompt", + type=str, + default=["A cake on the table."], + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=["examples/brushnet/src/test_image.jpg"], + nargs="+", + help=( + "A set of paths to the paintingnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--validation_mask", + type=str, + default=["examples/brushnet/src/test_mask.jpg"], + nargs="+", + help=( + "A set of paths to the paintingnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=5, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + ) + + # Freeze vae and text_encoder and set unet to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.train() + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for _ in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + + dataset = load_dataset( + "csv", + data_files=args.train_file + ) + + def add_prefix(example): + example["image"] = os.path.join(args.train_data_dir, example["image"]) + example["seg"] = os.path.join(args.train_data_dir, example["seg"]) + return example + + + # Preprocessing the datasets. + + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + dataset = dataset.map(add_prefix) + dataset = dataset.cast_column("image", datasets.Image()) + dataset = dataset.cast_column("seg", datasets.Image()) + + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + #No InterpolationMode.BILINEAR for mask + mask_transforms = transforms.Compose( + [ + transforms.Grayscale(), + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.NEAREST), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + ] + ) + + def create_random_convex_polygon_mask(raw_mask, y, x, height, width, min_vertices=5, max_vertices=12): + """ + 创建一个包含随机凸多边形的掩码。 + ... + """ + num_vertices = random.randint(min_vertices, max_vertices) + + # 多边形中心点 + center_x = x + width // 2 + center_y = y + height // 2 + + # 平均半径,决定多边形大小 + radius = min(width, height) // 2 + + # 创建顶点 + vertices = [] + for i in range(num_vertices): + # 随机角度 + angle = 2 * math.pi * i / num_vertices + random.uniform(-math.pi/16, math.pi/16) + # 随机半径 + r = radius + random.randint(-radius//4, radius//4) + # 计算坐标 + vx = center_x + int(r * math.cos(angle)) + vy = center_y + int(r * math.sin(angle)) + vertices.append((vx, vy)) + + + raw_mask = Image.fromarray(raw_mask).convert("L") + draw = ImageDraw.Draw(raw_mask) + + # 绘制多边形 + draw.polygon(vertices, outline=1, fill=1) + + # 转换图像为numpy数组 + mask = np.array(raw_mask) * 255 + + return mask + + + def gen_random_mask(mask, label): + # Generate random mask on Image type + width = random.randint(30, 200) + length = random.randint(30, 200) + mask = np.array(mask) + if np.sum(mask) == 0: + + x = random.randint(30, min(mask.shape[0] - width, mask.shape[0] - 30)) + y = random.randint(30, min(mask.shape[1] - length, mask.shape[1] - 30)) + if random.random() < 0.4: + mask[x:x + width, y:y + length] = 255 + else: + mask = create_random_convex_polygon_mask(mask, x, y, width, length) + else: + rows, cols = np.where(mask == 255) + min_row, max_row = np.min(rows), np.max(rows) + min_col, max_col = np.min(cols), np.max(cols) + + # Enlarge the mask + min_row = max(0, min_row - 15) + max_row = min(mask.shape[0], max_row + 15) + min_col = max(0, min_col - 15) + max_col = min(mask.shape[1], max_col + 15) + new_mask = np.zeros_like(mask) + if label == 0: + # Create new_mask and make intersection is 0 + x = random.randint(30, min(mask.shape[0] - width, mask.shape[0] - 30)) + y = random.randint(30, min(mask.shape[1] - length, mask.shape[1] - 30)) + + if random.random() > 0.5: + new_mask[x:x + width, y:y + length] = 255 + else: + new_mask = create_random_convex_polygon_mask(new_mask, x, y, width, length) + # if np.sum(np.bitwise_and(new_mask, mask)) != 0: + new_mask[min_row:max_row, min_col:max_col] = 0 + mask = new_mask + else: + if random.random() > 0.5: + new_mask = create_random_convex_polygon_mask(new_mask, min_row, min_col, max_row - min_row, max_col - min_col) + else: + new_mask[min_row:max_row, min_col:max_col] = 255 + mask = new_mask + mask = Image.fromarray(mask, "L") + return mask + + + def preprocess_train(examples): + + labels = examples["label"] + images = [image.convert("RGB") for image in examples[image_column]] + segs = [seg.convert("L") for seg in examples["seg"]] + captions = [] + for index in range(len(images)): + if labels[index] == 0: + segs[index] = gen_random_mask(segs[index], 0) + captions.append("Normal") + else: + + if random.random() > 0.6: + segs[index] = gen_random_mask(segs[index], 0) + captions.append("Normal") + else: + segs[index] = gen_random_mask(segs[index], 1) + captions.append("Polyp") + + + examples['text'] = captions + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["masks"] = [mask_transforms(seg) for seg in segs] + + examples["conditioning_pixel_values"] = [img * (1 - seg) for img, seg in zip(examples["pixel_values"], examples["masks"])] + examples["input_ids"] = tokenize_captions(examples) + + return examples + + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + masks = torch.stack([example["masks"] for example in examples]) + masks = masks.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "masks": masks, + "input_ids": input_ids, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + tracker_config.pop("validation_mask") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Convert masked images to latent space + masked_latents = vae.encode( + batch["conditioning_pixel_values"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype) + ).latent_dist.sample() + masked_latents = masked_latents * vae.config.scaling_factor + + # resize the mask to latents shape as we concatenate the mask to the latents + masks = torch.nn.functional.interpolate( + batch["masks"], + size=( + latents.shape[-2], + latents.shape[-1] + ) + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # concatenate the noised latents with the mask and the masked latents + latent_model_input = torch.cat([noisy_latents, masks, masked_latents], dim=1) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + # Lesion-guided loss + lg_loss = F.mse_loss((masks * model_pred).float(), (masks * target).float(), reduction="mean") + loss = loss + 0.5 * lg_loss + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = mse_loss_weights / snr + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights / (snr + 1) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + revision=args.revision, + variant=args.variant, + ) + pipeline.save_pretrained(args.output_dir) + + # Run a final round of validation. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + brushnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + + if args.push_to_hub: + save_model_card(args, repo_id, image_logs, repo_folder=args.output_dir) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main()