#!/usr/bin/env python
# coding: utf-8

# In[1]:

import torch
import numpy as np
import matplotlib.pyplot as plt

from diffusers import StableDiffusionPipeline, DDIMScheduler


# In[2]:

cache_dir = ""
repo_id = "stabilityai/stable-diffusion-2-1-base"

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                          clip_sample=False, set_alpha_to_one=False, steps_offset=1)

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
pipe = StableDiffusionPipeline.from_pretrained(repo_id, cache_dir=cache_dir, scheduler=scheduler).to(device)


# In[3]:

@torch.no_grad()
def latent2image(vae, latents: torch.Tensor) -> np.ndarray:
    """Decode a latent tensor into a uint8 RGB image."""
    latents = 1 / 0.18215 * latents.detach()
    image = vae.decode(latents)['sample']
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    image = (image * 255).astype(np.uint8)
    return image


@torch.no_grad()
def image2latent(vae, image: np.ndarray) -> torch.Tensor:
    """Encode a uint8 RGB image into a scaled latent tensor."""
    image = torch.from_numpy(image).float() / 127.5 - 1
    image = image.permute(2, 0, 1).unsqueeze(0).to(device)
    latents = vae.encode(image)['latent_dist'].sample()
    latents = latents * 0.18215
    return latents


@torch.no_grad()
def prompt2embeddings(tokenizer, text_encoder, prompt: str) -> torch.Tensor:
    """Return the concatenated [unconditional, conditional] text embeddings for classifier-free guidance."""
    uncond_input = tokenizer(
        [""], padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(pipe.device))[0]
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(pipe.device))[0]
    return torch.cat([uncond_embeddings, text_embeddings])


@torch.no_grad()
def prompt_aware_adjustment(pipe, latents: torch.Tensor, prompt: str, guidance_scale: float = 7.5):
    """
    Prompt-aware Adjustment, i.e., function h.
    """
    encoder_hidden_states = prompt2embeddings(pipe.tokenizer, pipe.text_encoder, prompt)

    # Add two steps of noise (diffuse the latent to timestep t=2)
    noise = torch.randn_like(latents)
    t = torch.tensor([2])
    latents = pipe.scheduler.add_noise(latents, noise, t)

    # Denoise back over timesteps [2, 1] with classifier-free guidance
    timesteps = [2, 1]
    for t in timesteps:
        latents_input = torch.cat([latents] * 2)  # duplicate latents for CFG (uncond + cond)
        noise_pred = pipe.unet(latents_input, t, encoder_hidden_states).sample
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
    return latents


# In[4]:

torch.manual_seed(3)
zt = torch.randn(1, 4, 64, 64)
with torch.no_grad():
    z0 = pipe("a cute dog", latents=zt, output_type="latent")[0]


# In[5]:

# Only taking steps of size 1
pipe.scheduler.set_timesteps(1000)


# In[6]:

img_x = latent2image(pipe.vae, z0)
z0e = image2latent(pipe.vae, img_x)
z0tilde = prompt_aware_adjustment(pipe, z0e, "a cute dog", guidance_scale=7.5)


# In[7]:

# Error between the generated latent and the latent after a decode-encode round trip; in the paper this is 16
((z0 - z0e) ** 2).sum()


# In[8]:

# Error between the generated latent and the latent after decode-encode plus prompt-aware adjustment;
# in the paper this is 12. It should be lower than the previous error.
((z0 - z0tilde) ** 2).sum()


# In[9]:

fig, ax = plt.subplots(1, 3, figsize=(16, 8))
ax[0].imshow(latent2image(pipe.vae, z0))
ax[1].imshow(latent2image(pipe.vae, z0e))
ax[2].imshow(latent2image(pipe.vae, z0tilde))
plt.show()
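
# In[ ]:

# Optional sanity check (not part of the original notebook): a minimal sketch that prints
# the two reconstruction errors side by side and confirms that the prompt-aware adjustment
# lowers the error, as the two cells above expect.
err_decode_encode = ((z0 - z0e) ** 2).sum().item()
err_adjusted = ((z0 - z0tilde) ** 2).sum().item()
print(f"decode-encode error:            {err_decode_encode:.2f}")
print(f"with prompt-aware adjustment:   {err_adjusted:.2f}")
assert err_adjusted < err_decode_encode, "prompt-aware adjustment did not reduce the error"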