@@ -43,6 +43,21 @@ def _get_coreml_inputs(sample_inputs, args):
4343 ) for k , v in sample_inputs .items ()
4444 ]
4545
46+ # Simpler version of `DiagonalGaussianDistribution` with only needed calculations
47+ # as implemented in vae.py as part of the AutoencoderKL class
48+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
49+ # coremltools-6.1 does not yet implement the randn operation with the option of setting a random seed
50+ class CoreMLDiagonalGaussianDistribution (object ):
51+ def __init__ (self , parameters , noise ):
52+ self .parameters = parameters
53+ self .noise = noise
54+ self .mean , self .logvar = torch .chunk (parameters , 2 , dim = 1 )
55+ self .logvar = torch .clamp (self .logvar , - 30.0 , 20.0 )
56+ self .std = torch .exp (0.5 * self .logvar )
57+
58+ def sample (self ) -> torch .FloatTensor :
59+ x = self .mean + self .std * self .noise
60+ return x
4661
4762def compute_psnr (a , b ):
4863 """ Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
@@ -140,7 +155,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
140155
141156def quantize_weights_to_8bits (args ):
142157 for model_name in [
143- "text_encoder" , "vae_decoder" , "unet" , "unet_chunk1" ,
158+ "text_encoder" , "vae_decoder" , "vae_encoder" , " unet" , "unet_chunk1" ,
144159 "unet_chunk2" , "safety_checker"
145160 ]:
146161 out_path = _get_out_path (args , model_name )
@@ -190,6 +205,7 @@ def bundle_resources_for_swift_cli(args):
190205 # Compile model using coremlcompiler (Significantly reduces the load time for unet)
191206 for source_name , target_name in [("text_encoder" , "TextEncoder" ),
192207 ("vae_decoder" , "VAEDecoder" ),
208+ ("vae_encoder" , "VAEEncoder" ),
193209 ("unet" , "Unet" ),
194210 ("unet_chunk1" , "UnetChunk1" ),
195211 ("unet_chunk2" , "UnetChunk2" ),
@@ -453,6 +469,159 @@ def forward(self, z):
453469 gc .collect ()
454470
455471
472+ def convert_vae_encoder (pipe , args ):
473+ """ Converts the VAE Encoder component of Stable Diffusion
474+ """
475+ out_path = _get_out_path (args , "vae_encoder" )
476+ if os .path .exists (out_path ):
477+ logger .info (
478+ f"`vae_encoder` already exists at { out_path } , skipping conversion."
479+ )
480+ return
481+
482+ if not hasattr (pipe , "unet" ):
483+ raise RuntimeError (
484+ "convert_unet() deletes pipe.unet to save RAM. "
485+ "Please use convert_vae_encoder() before convert_unet()" )
486+
487+ sample_shape = (
488+ 1 , # B
489+ 3 , # C (RGB range from -1 to 1)
490+ (args .latent_h or pipe .unet .config .sample_size ) * 8 , # H
491+ (args .latent_w or pipe .unet .config .sample_size ) * 8 , # w
492+ )
493+
494+ noise_shape = (
495+ 1 , # B
496+ 4 , # C
497+ pipe .unet .config .sample_size , # H
498+ pipe .unet .config .sample_size , # w
499+ )
500+
501+ float_value_shape = (
502+ 1 ,
503+ 1 ,
504+ )
505+
506+ sqrt_alphas_cumprod_torch_shape = torch .tensor ([[0.2 ,]])
507+ sqrt_one_minus_alphas_cumprod_torch_shape = torch .tensor ([[0.8 ,]])
508+
509+ sample_vae_encoder_inputs = {
510+ "sample" : torch .rand (* sample_shape , dtype = torch .float16 ),
511+ "diagonal_noise" : torch .rand (* noise_shape , dtype = torch .float16 ),
512+ "noise" : torch .rand (* noise_shape , dtype = torch .float16 ),
513+ "sqrt_alphas_cumprod" : torch .rand (* float_value_shape , dtype = torch .float16 ),
514+ "sqrt_one_minus_alphas_cumprod" : torch .rand (* float_value_shape , dtype = torch .float16 ),
515+ }
516+
517+ class VAEEncoder (nn .Module ):
518+ """ Wrapper nn.Module wrapper for pipe.encode() method
519+ """
520+
521+ def __init__ (self ):
522+ super ().__init__ ()
523+ self .quant_conv = pipe .vae .quant_conv
524+ self .alphas_cumprod = pipe .scheduler .alphas_cumprod
525+ self .encoder = pipe .vae .encoder
526+
527+ # Because CoreMLTools does not support the torch.randn op, we pass in both
528+ # the diagonal Noise for the `DiagonalGaussianDistribution` operation and
529+ # the noise tensor combined with precalculated `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
530+ # for faster computation.
531+ def forward (self , sample , diagonal_noise , noise , sqrt_alphas_cumprod , sqrt_one_minus_alphas_cumprod ):
532+ h = self .encoder (sample )
533+ moments = self .quant_conv (h )
534+ posterior = CoreMLDiagonalGaussianDistribution (moments , diagonal_noise )
535+ posteriorSample = posterior .sample ()
536+
537+ # Add the scaling operation and the latent noise for faster computation
538+ init_latents = 0.18215 * posteriorSample
539+ result = self .add_noise (init_latents , noise , sqrt_alphas_cumprod , sqrt_one_minus_alphas_cumprod )
540+ return result
541+
542+ def add_noise (
543+ self ,
544+ original_samples : torch .FloatTensor ,
545+ noise : torch .FloatTensor ,
546+ sqrt_alphas_cumprod : torch .FloatTensor ,
547+ sqrt_one_minus_alphas_cumprod : torch .FloatTensor
548+ ) -> torch .FloatTensor :
549+ noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
550+ return noisy_samples
551+
552+
553+ baseline_encoder = VAEEncoder ().eval ()
554+
555+ # No optimization needed for the VAE Encoder as it is a pure ConvNet
556+ traced_vae_encoder = torch .jit .trace (
557+ baseline_encoder , (
558+ sample_vae_encoder_inputs ["sample" ].to (torch .float32 ),
559+ sample_vae_encoder_inputs ["diagonal_noise" ].to (torch .float32 ),
560+ sample_vae_encoder_inputs ["noise" ].to (torch .float32 ),
561+ sqrt_alphas_cumprod_torch_shape .to (torch .float32 ),
562+ sqrt_one_minus_alphas_cumprod_torch_shape .to (torch .float32 )
563+ ))
564+
565+ modify_coremltools_torch_frontend_badbmm ()
566+ coreml_vae_encoder , out_path = _convert_to_coreml (
567+ "vae_encoder" , traced_vae_encoder , sample_vae_encoder_inputs ,
568+ ["latent_dist" ], args )
569+
570+ # Set model metadata
571+ coreml_vae_encoder .author = f"Please refer to the Model Card available at huggingface.co/{ args .model_version } "
572+ coreml_vae_encoder .license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
573+ coreml_vae_encoder .version = args .model_version
574+ coreml_vae_encoder .short_description = \
575+ "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
576+ "Please refer to https://arxiv.org/abs/2112.10752 for details."
577+
578+ # Set the input descriptions
579+ coreml_vae_encoder .input_description ["sample" ] = \
580+ "An image of the correct size to create the latent space with, image2image and in-painting."
581+ coreml_vae_encoder .input_description ["diagonal_noise" ] = \
582+ "Latent noise for `DiagonalGaussianDistribution` operation."
583+ coreml_vae_encoder .input_description ["noise" ] = \
584+ "Latent noise for use with strength parameter of image2image"
585+ coreml_vae_encoder .input_description ["sqrt_alphas_cumprod" ] = \
586+ "Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
587+ coreml_vae_encoder .input_description ["sqrt_one_minus_alphas_cumprod" ] = \
588+ "Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
589+
590+ # Set the output descriptions
591+ coreml_vae_encoder .output_description [
592+ "latent_dist" ] = "The latent embeddings from the unet model from the input image."
593+
594+ _save_mlpackage (coreml_vae_encoder , out_path )
595+
596+ logger .info (f"Saved vae_encoder into { out_path } " )
597+
598+ # Parity check PyTorch vs CoreML
599+ if args .check_output_correctness :
600+ baseline_out = baseline_encoder (
601+ sample = sample_vae_encoder_inputs ["sample" ].to (torch .float32 ),
602+ diagonal_noise = sample_vae_encoder_inputs ["diagonal_noise" ].to (torch .float32 ),
603+ noise = sample_vae_encoder_inputs ["noise" ].to (torch .float32 ),
604+ sqrt_alphas_cumprod = sqrt_alphas_cumprod_torch_shape ,
605+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod_torch_shape ,
606+ ).numpy (),
607+
608+ coreml_out = list (
609+ coreml_vae_encoder .predict (
610+ {
611+ "sample" : sample_vae_encoder_inputs ["sample" ].numpy (),
612+ "diagonal_noise" : sample_vae_encoder_inputs ["diagonal_noise" ].numpy (),
613+ "noise" : sample_vae_encoder_inputs ["noise" ].numpy (),
614+ "sqrt_alphas_cumprod" : sqrt_alphas_cumprod_torch_shape .numpy (),
615+ "sqrt_one_minus_alphas_cumprod" : sqrt_one_minus_alphas_cumprod_torch_shape .numpy ()
616+ }).values ())
617+
618+ report_correctness (baseline_out [0 ], coreml_out [0 ],
619+ "vae_encoder baseline PyTorch to baseline CoreML" )
620+
621+ del traced_vae_encoder , pipe .vae .encoder , coreml_vae_encoder
622+ gc .collect ()
623+
624+
456625def convert_unet (pipe , args ):
457626 """ Converts the UNet component of Stable Diffusion
458627 """
@@ -801,7 +970,12 @@ def main(args):
801970 logger .info ("Converting vae_decoder" )
802971 convert_vae_decoder (pipe , args )
803972 logger .info ("Converted vae_decoder" )
804-
973+
974+ if args .convert_vae_encoder :
975+ logger .info ("Converting vae_encoder" )
976+ convert_vae_encoder (pipe , args )
977+ logger .info ("Converted vae_encoder" )
978+
805979 if args .convert_unet :
806980 logger .info ("Converting unet" )
807981 convert_unet (pipe , args )
@@ -835,6 +1009,7 @@ def parser_spec():
8351009 # Select which models to export (All are needed for text-to-image pipeline to function)
8361010 parser .add_argument ("--convert-text-encoder" , action = "store_true" )
8371011 parser .add_argument ("--convert-vae-decoder" , action = "store_true" )
1012+ parser .add_argument ("--convert-vae-encoder" , action = "store_true" )
8381013 parser .add_argument ("--convert-unet" , action = "store_true" )
8391014 parser .add_argument ("--convert-safety-checker" , action = "store_true" )
8401015 parser .add_argument (
0 commit comments