@torch.no_grad()
def latent2image(vae, latents: torch.Tensor) -> np.ndarray:
    """Decode ``latents`` with ``vae`` and return the first image as uint8.

    Args:
        vae: Object whose ``decode(latents)`` result can be indexed with
            ``'sample'`` to obtain the decoded image tensor.
        latents: Latent tensor; divided by 0.18215 before decoding
            (presumably the VAE latent scaling factor -- confirm against
            the model config).

    Returns:
        The first item of the decoded batch as a channels-last ``np.uint8``
        array with values in ``[0, 255]``.
    """
    rescaled = (1 / 0.18215) * latents.detach()
    decoded = vae.decode(rescaled)['sample']

    # x / 2 + 0.5 maps a [-1, 1] range onto [0, 1]; anything outside is clipped.
    unit_range = torch.clamp(decoded / 2 + 0.5, 0, 1)

    # Move the channel axis last, then keep only the first batch element.
    first_image = unit_range.cpu().permute(0, 2, 3, 1).numpy()[0]

    # astype truncates toward zero, so e.g. 127.5 becomes 127.
    return (first_image * 255).astype(np.uint8)
@torch.no_grad()
def image2latent(vae, image: np.ndarray) -> torch.Tensor:
    """Encode a single image into scaled latents using ``vae``.

    Args:
        vae: Model exposing ``device`` and ``dtype`` attributes and an
            ``encode(x)`` method whose result can be indexed with
            ``'latent_dist'`` to get an object with ``sample()``.
        image: Rank-3 array with the channel axis last. Values are mapped
            with ``x / 127.5 - 1``, so a 0..255 range is assumed --
            TODO confirm with callers.

    Returns:
        Latents sampled from the encoder distribution, multiplied by 0.18215.
    """
    # x / 127.5 - 1 maps 0..255 onto -1..1.
    pixels = torch.from_numpy(image).float() / 127.5 - 1

    # Channels-last -> channels-first, plus a leading batch axis of size 1.
    # Use the VAE's own device/dtype instead of a module-level ``device``
    # global, so the input always matches the model weights.
    pixels = pixels.permute(2, 0, 1).unsqueeze(0).to(device=vae.device, dtype=vae.dtype)

    latents = vae.encode(pixels)['latent_dist'].sample()
    return latents * 0.18215
@torch.no_grad()
def prompt2embeddings(tokenizer, text_encoder, prompt: str) -> torch.Tensor:
    """Build the [unconditional, conditional] text embeddings for a prompt.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length``; its
            result must provide ``input_ids``.
        text_encoder: Callable encoder exposing ``device``; element ``[0]``
            of its output is used as the embedding tensor.
        prompt: Text prompt for the conditional branch.

    Returns:
        The empty-prompt embeddings concatenated with the prompt embeddings
        along dim 0, unconditional first.
    """
    # Token ids must live on the encoder's device. The encoder argument is
    # used directly rather than a ``pipe`` object that is not a parameter
    # of this function.
    device = text_encoder.device

    uncond_input = tokenizer(
        [""],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    return torch.cat([uncond_embeddings, text_embeddings])
@torch.no_grad()
def prompt_aware_adjustment(
    pipe, latents: torch.Tensor, prompt: str, guidance_scale: float = 7.5
) -> torch.Tensor:
    """
    Prompt-aware Adjustment, i.e., function h.

    Noises ``latents`` at timestep 2, then runs two classifier-free-guided
    denoising steps (timesteps 2 and 1) conditioned on ``prompt``.

    Args:
        pipe: Pipeline-like object providing ``tokenizer``, ``text_encoder``,
            ``scheduler`` and ``unet``.
        latents: Latents to adjust.
        prompt: Text prompt used for the conditional branch.
        guidance_scale: Weight applied to the (text - unconditional) noise
            difference.

    Returns:
        The latents after the two scheduler steps.

    NOTE(review): assumes ``pipe.scheduler`` is already configured so that
    ``step`` accepts timesteps 2 and 1 -- confirm at the call site.
    """
    encoder_hidden_states = prompt2embeddings(pipe.tokenizer, pipe.text_encoder, prompt)

    # Add two steps of noise
    noise = torch.randn_like(latents)
    noise_timestep = torch.tensor([2])
    latents = pipe.scheduler.add_noise(latents, noise, noise_timestep)

    # Forward predict the noise steps
    for timestep in (2, 1):
        # Duplicate the batch so a single UNet call yields both the
        # unconditional and the text-conditioned prediction (CFG).
        latents_input = torch.cat([latents] * 2)
        noise_pred = pipe.unet(latents_input, timestep, encoder_hidden_states).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = pipe.scheduler.step(noise_pred, timestep, latents).prev_sample

    return latents