only train and infer stage 2 model #119

Open · tzayuan opened this issue May 26, 2024 · 1 comment

tzayuan commented May 26, 2024

Hi, @0x3f3f3f3fun
I want to train and run inference with only the stage 2 model (stage 1 will be handled separately by another model that produces the restored image). Could you provide the command structure for this, i.e. how to fine-tune only the stage 2 model and save it, and how to run inference with only the stage 2 model? Thanks!

BRs,
tzayuan

0x3f3f3f3fun (Collaborator) commented May 26, 2024

Hi :)
To run inference with only the stage 2 model, you can take the following steps:

  1. Implement a custom pipeline without any restoration module in utils/helpers.py.
class CustomPipeline(Pipeline):

    def __init__(self, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        # There is no stage-1 restoration module, so pass None for it.
        super().__init__(None, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # In our experiments, the output of the restoration module (a.k.a. the condition for the subsequent
        # IRControlNet) is resized to a resolution >= 512, since both the pretrained SD2 and IRControlNet
        # are trained on 512x512. Here we directly use the lq as the condition.
        if min(lq.shape[2:]) < 512:
            clean = resize_short_edge_to(lq, size=512)
        else:
            clean = lq
        return clean
  2. Implement a custom inference loop in utils/inference.py.
class CustomInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        # Nothing to do.
        pass

    def init_pipeline(self) -> None:
        # Instantiate our custom pipeline.
        self.pipeline = CustomPipeline(self.cldm, self.diffusion, self.cond_fn, self.args.device)
  3. Add the custom inference loop to inference.py as a task option.
def main():
    args = parse_args()
    args.device = check_device(args.device)
    set_seed(args.seed)
    if args.version == "v1":
        V1InferenceLoop(args).run()
    else:
        supported_tasks = {
            "sr": BSRInferenceLoop,
            "dn": BIDInferenceLoop,
            "fr": BFRInferenceLoop,
            "fr_bg": UnAlignedBFRInferenceLoop,
            "custom": CustomInferenceLoop
        }
        supported_tasks[args.task](args).run()
        print("done!")

The provided code was written directly in the GitHub comment box and has not been tested, but I think it will work.

As for fine-tuning the stage 2 model, there are two ways to do it:

  • Apply the restoration model to the low-quality images offline, and write a new dataset class to load the gt-condition pairs (a sketch follows below).
  • Copy the model definition to the model directory, load it in train_stage2.py, and replace SwinIR with your restoration model.

The second way is recommended.
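
If you go with the first way, a minimal sketch of such a dataset class could look like the following; the directory layout, file naming, value range, and returned keys are assumptions, so adapt it to the batch format that train_stage2.py actually expects.

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset


class PairedGTCondDataset(Dataset):
    # Loads pre-computed (gt, condition) image pairs, assuming the restoration model
    # has already been applied offline and gt_dir/<name> matches cond_dir/<name>.

    def __init__(self, gt_dir: str, cond_dir: str) -> None:
        self.gt_dir = gt_dir
        self.cond_dir = cond_dir
        self.names = sorted(os.listdir(gt_dir))

    def __len__(self) -> int:
        return len(self.names)

    def _load(self, path: str) -> np.ndarray:
        # HWC float32 in [0, 1]; rescale to whatever range the training code expects.
        img = Image.open(path).convert("RGB")
        return np.asarray(img, dtype=np.float32) / 255.0

    def __getitem__(self, index: int) -> dict:
        name = self.names[index]
        gt = self._load(os.path.join(self.gt_dir, name))
        cond = self._load(os.path.join(self.cond_dir, name))
        # Key names here are placeholders; use whatever the existing stage-2 dataset returns.
        return dict(gt=gt, cond=cond)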
