
How to train DiffBIR for face retouching? #113

Open · luser350 opened this issue May 7, 2024 · 5 comments

luser350 commented May 7, 2024

Hi, thanks for sharing your awesome work. I want to train DiffBIR on the 1024x1024 FFHQR dataset, using the FFHQ dataset as the input (also 1024x1024).
I have modified the CodeFormer dataset's __getitem__():

def __getitem__(self, index: int) -> Dict[str, Union[np.ndarray, str]]:
    # load gt image
    img_gt = None
    while img_gt is None:
        # load meta file
        image_file = self.image_files[index]
        gt_path = image_file["image_path"]
        prompt = image_file["prompt"]
        img_gt = self.load_gt_image(gt_path)
        if img_gt is None:
            print(f"failed to load {gt_path}, try another image")
            index = random.randint(0, len(self) - 1)

    # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
    img_gt = (img_gt[..., ::-1] / 255.0).astype(np.float32)
    h, w, _ = img_gt.shape
    # drop the text prompt half of the time
    if np.random.uniform() < 0.5:
        prompt = ""

    # ------------------------ generate lq image ------------------------ #
    # the un-retouched FFHQ image with the same file name serves as the lq condition
    lq_path = gt_path.replace("ffhqr", "ffhq")
    img_lq = cv2.imread(lq_path)
    img_lq = img_lq.astype(np.float32) / 255.0

    # BGR to RGB, [-1, 1]
    gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
    # BGR to RGB, [0, 1]
    lq = img_lq[..., ::-1].astype(np.float32)

    return gt, lq, prompt
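
To sanity-check the gt/lq pairing outside the dataset class, I run something like the following (just a sketch, not tested; the paths are placeholders for wherever the FFHQR/FFHQ copies live):

import cv2
import numpy as np

# check that one retouched (gt) / original (lq) pair loads and the value ranges match
gt_path = "data/ffhqr/00006.png"            # placeholder path to an FFHQR image
lq_path = gt_path.replace("ffhqr", "ffhq")  # the matching un-retouched FFHQ image

img_gt = cv2.imread(gt_path).astype(np.float32) / 255.0  # BGR, [0, 1]
img_lq = cv2.imread(lq_path).astype(np.float32) / 255.0  # BGR, [0, 1]

gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)  # RGB, [-1, 1]
lq = img_lq[..., ::-1].astype(np.float32)            # RGB, [0, 1]
print(gt.shape, gt.min(), gt.max())  # expect (1024, 1024, 3), values roughly in [-1, 1]
print(lq.shape, lq.min(), lq.max())  # expect (1024, 1024, 3), values in [0, 1]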

I have generated the training.list and validation.list for ffhqr.
Please guide me on the following:

  1. Which pre-trained model should I use: face_swinir_v1.ckpt or v1_face.pth?
  2. Do I need to train stage_1, or can I directly train stage_2?
  3. I want both the input and the output to be 1024x1024.

luser350 (Author) commented May 9, 2024

@chxy95 @hejingwenhejingwen @xinntao @WenlongZhang0517 can somebody answer my question, please?

@0x3f3f3f3fun (Collaborator) commented:

Hello!

  1. v1_face.pth. This checkpoint contains the weights of IRControlNet, which takes a smooth face image as the condition and outputs a high-quality restoration result.
  2. I am not familiar with face retouching, but I think you can directly train stage_2 with the original face image as the condition.
  3. If you want to train at a resolution of 1024, the first thing you need to do is make sure the pretrained VAE can encode and decode your images successfully, because it was originally trained at a resolution of 256. After that, you can directly use 1024x1024 images as inputs; there is nothing else to do.

Feel free to contact me if you have any further questions.

luser350 (Author) commented:

Hi, thanks for replying. Which VAE are you referring to, and how should I verify it?

@0x3f3f3f3fun (Collaborator) commented:

Follow the instructions below:

  1. Load ControlLDM.
  2. Load a 1024x1024 image and convert it to a tensor $x$ (nchw, rgb, range in [-1, 1]).
  3. Call ControlLDM.vae_encode() to encode $x$ into a latent code $z$.
  4. Call ControlLDM.vae_decode() to decode $z$ and get the result $x'$ (nchw, rgb, range in [-1, 1]).
  5. $x'$ should be very close to $x$.

Here is an example (not tested):

from model.cldm import ControlLDM
from utils.common import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
import numpy as np
import torch

cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
# VAE is contained in pretrained SD
sd = torch.load("path/to/pretrained_sd_v2.1", map_location="cpu")
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
cldm.eval().to("cuda")

# load image and convert to a tensor (nchw, rgb, range in [-1, 1])
img = Image.open("xxx").convert("RGB")
x = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0)
x = (x * 2 - 1).permute(2, 0, 1).unsqueeze(0).to("cuda")

with torch.no_grad():
    z = cldm.vae_encode(x)
    x_decoded = cldm.vae_decode(z)

# convert x_decoded back to an image: [-1, 1] -> [0, 255], nchw -> hwc
arr = ((x_decoded[0].clamp(-1, 1) + 1) / 2 * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
img_decoded = Image.fromarray(arr)

# save the image and compare it with the input
img_decoded.save("decoded_image.png")

luser350 (Author) commented May 16, 2024

Hi @0x3f3f3f3fun, I have completed the above steps. Here is my script:

from model.cldm import ControlLDM
from utils.common import instantiate_from_config
from omegaconf import OmegaConf
import torch
import torchvision.transforms as transforms
from PIL import Image

# Define the transformation
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 
])

cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
# VAE is contained in pretrained SD
sd = torch.load("pretrained/v2-1_512-ema-pruned.ckpt", map_location="cpu")
#unused = cldm.load_pretrained_sd(sd)
#print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
cldm.eval().to("cpu")

# load image and convert to tensor
img = Image.open("00006.png").convert("RGB")
x = transform(img)
x = x.unsqueeze(0)
#x = x.permute(0, 2, 3, 1)

with torch.no_grad():
    z = cldm.vae_encode(x)
    x_decoded = cldm.vae_decode(z)

# reverse normalization
x_decoded = x_decoded.squeeze()#.permute(1, 2, 0) 
x_decoded = (x_decoded * 0.5) + 0.5 
img_decoded = (transforms.ToPILImage()(x_decoded.cpu().clamp(0, 1)))

img_decoded.save("decoded_image.png")

I have commented out these two lines:

unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")

because they were raising the following KeyError:

Traceback (most recent call last):
  File "/home/luser350/Desktop/diffbir/DiffBIR/vae.py", line 16, in <module>
    unused = cldm.load_pretrained_sd(sd)
  File "/home/luser350/anaconda3/envs/diffbir/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/luser350/Desktop/diffbir/DiffBIR/model/cldm.py", line 51, in load_pretrained_sd
    init_sd[key] = sd[target_key].clone()
KeyError: 'model.diffusion_model.time_embed.0.weight'

My input image: 00006.png
Decoded image: decoded_image.png

This result is expected, since cldm was unable to load the SD weights. How can I solve this KeyError?
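
One thing worth checking (a sketch, not tested): the official v2-1_512-ema-pruned.ckpt appears to store its weights under a "state_dict" key, so passing the raw torch.load() result to load_pretrained_sd() would leave it looking for "model.diffusion_model..." keys at the top level and failing. Unwrapping the checkpoint first might fix the load:

from model.cldm import ControlLDM
from utils.common import instantiate_from_config
from omegaconf import OmegaConf
import torch

cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))

# Stable Diffusion checkpoints are usually saved as {"state_dict": {...}, ...};
# unwrap that level (if present) before handing the weights to load_pretrained_sd()
ckpt = torch.load("pretrained/v2-1_512-ema-pruned.ckpt", map_location="cpu")
sd = ckpt.get("state_dict", ckpt)
print(list(sd.keys())[:3])  # should now show keys like "model.diffusion_model...."
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")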
