In this notebook, we show how to fine-tune Stable Diffusion XL (SDXL) with DreamBooth and LoRA using some of the most popular SOTA methods.
Learn more about the techniques used in this example [here](linke to blogpost).
Let's get started 🧪
# Install dependencies.
!pip install xformers bitsandbytes transformers accelerate wandb dadaptation prodigyopt -q
!pip install peft -q
Make sure to install `diffusers` from `main`:
!pip install git+https://github.com/huggingface/diffusers.git -q
Download diffusers SDXL DreamBooth training script.
!wget https://raw.githubusercontent.com/huggingface/diffusers/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
Let's get our training data! For this example, we'll download some images from the Hub.
If you already have a dataset on the Hub you wish to use, you can skip this part and go straight to the "Prep for training 💻" section, where you'll simply specify the dataset name.
If your images are saved locally and/or you want to add BLIP-generated captions, pick option 1 or 2 below.
Option 1: upload example images from your local files:
import os
from google.colab import files
# pick a name for the image folder
local_dir = "./my_folder" #@param
parent_dir = os.getcwd()  # remember where we started
os.makedirs(local_dir, exist_ok=True)  # don't fail if the folder already exists
os.chdir(local_dir)
# choose and upload local images into the newly created directory
uploaded_images = files.upload()
os.chdir(parent_dir)  # back to the parent directory
Option 2: download example images from the Hub:
from huggingface_hub import snapshot_download
local_dir = "./3d_icon" #@param
dataset_to_download = "LinoyTsaban/3d_icon" #@param
snapshot_download(
dataset_to_download,
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
Preview the images:
from PIL import Image
def image_grid(imgs, rows, cols, resize=256):
    assert len(imgs) == rows * cols
    if resize is not None:
        imgs = [img.resize((resize, resize)) for img in imgs]
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
import glob
local_dir = "./3d_icon"
img_paths = f"{local_dir}/*.jpg"
imgs = [Image.open(path) for path in glob.glob(img_paths)]
num_imgs_to_preview = 5
image_grid(imgs[:num_imgs_to_preview], 1, num_imgs_to_preview)
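`image_grid` pastes the images in row-major order: image `i` lands in column `i % cols` and row `i // cols`, and every tile is `w` by `h` pixels. The stdlib sketch below computes only those paste positions and the canvas size, so a layout can be sanity-checked without loading any images; `grid_offsets` and `grid_size` are hypothetical helper names, not part of PIL or any other library.

```python
def grid_offsets(num_images, rows, cols, w, h):
    """Top-left (x, y) paste position of each tile, in row-major order."""
    if num_images != rows * cols:
        raise ValueError("num_images must equal rows * cols")
    # column index is i % cols, row index is i // cols
    return [((i % cols) * w, (i // cols) * h) for i in range(num_images)]


def grid_size(rows, cols, w, h):
    # the canvas is `cols` tiles wide and `rows` tiles tall
    return (cols * w, rows * h)


print(grid_offsets(6, rows=2, cols=3, w=256, h=256))
# [(0, 0), (256, 0), (512, 0), (0, 256), (256, 256), (512, 256)]
print(grid_size(2, 3, 256, 256))  # (768, 512)
```

With the preview settings used above (one row of five 256-pixel tiles), the last image is pasted at x = 1024.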
Load BLIP2 to auto caption your images:
Note: if you downloaded the `LinoyTsaban/3d_icon` dataset from the Hub, you'll find it already contains captions (generated with BLIP and prefixed with a token identifier) in the `metadata.jsonl` file.
You can skip this part if you wish to train on that dataset using the existing captions.
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
device = "cuda" if torch.cuda.is_available() else "cpu"
# half precision on GPU; fall back to full precision on CPU
dtype = torch.float16 if device == "cuda" else torch.float32
# load the BLIP-2 processor and model
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
).to(device)
## IMAGE CAPTIONING ##
def caption_images(input_image):
    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, dtype)
    pixel_values = inputs.pixel_values
    generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
import glob
from PIL import Image
# create a list of (path, PIL.Image) pairs
local_dir = "./3d_icon/"
imgs_and_paths = [(path, Image.open(path)) for path in glob.glob(f"{local_dir}*.jpg")]
Now let's add the concept token identifier (e.g. TOK) to each caption using a caption prefix.
Note: when training with pivotal tuning, this token identifier (e.g. TOK) is only a placeholder and will be mapped to new tokens that we insert into the tokenizers, so there's no need to spend too much time choosing it!
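The training log later in this notebook shows what this mapping looks like in practice: the validation prompt "a TOK icon of an astronaut riding a horse, in the style of TOK" is logged as "a <s0><s1> icon of an astronaut riding a horse, in the style of <s0><s1>". The sketch below reproduces that string rewrite, assuming two new tokens per placeholder (the number seen in that log). `expand_placeholder` is a hypothetical helper for illustration, not the training script's own code.

```python
def expand_placeholder(prompt, placeholder="TOK", num_new_tokens=2):
    """Replace every occurrence of `placeholder` with <s0><s1>...<s{n-1}>.

    Hypothetical helper mirroring the TOK -> <s0><s1> rewrite seen in the training log.
    """
    new_tokens = "".join(f"<s{i}>" for i in range(num_new_tokens))
    return prompt.replace(placeholder, new_tokens)


print(expand_placeholder("a TOK icon of an astronaut riding a horse, in the style of TOK"))
# a <s0><s1> icon of an astronaut riding a horse, in the style of <s0><s1>
```

Because this sketch uses a plain substring replacement, a placeholder that also appears inside an ordinary word would be rewritten there too, which is one practical reason to keep the placeholder rare.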
Change the prefix according to the concept you're training on. For this example, instead of "In the style of TOK" we can use "3d icon in the style of TOK" or "a TOK 3d style icon".
The cell below saves the image file names and their corresponding prompts to a metadata file for training.
import json
from IPython.display import display, Markdown
caption_prefix = "3d icon in the style of TOK, " #@param
# saves each caption and corresponding image to a metadata.jsonl file
with open(f'{local_dir}metadata.jsonl', 'w') as outfile:
    for img in imgs_and_paths:
        caption = caption_prefix + caption_images(img[1]).split("\n")[0]
        entry = {"file_name": img[0].split("/")[-1], "prompt": caption}
        json.dump(entry, outfile)
        outfile.write('\n')
display(Markdown(f"Your image captions are ready here: {local_dir}metadata.jsonl"))
Your image captions are ready here: ./3d_icon/metadata.jsonl
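Before launching training, it can be worth reading `metadata.jsonl` back to confirm that every line has a `file_name` and a `prompt`, and that each prompt starts with your caption prefix. A minimal stdlib sketch follows; `read_metadata` is a hypothetical helper, and the two sample entries (`img_0.jpg`, `img_1.jpg` and their captions) are made up for the demo, which writes to a temporary file rather than your real metadata.

```python
import json
import os
import tempfile


def read_metadata(path, caption_key="prompt"):
    """Load a metadata.jsonl file and check each entry has a file name and a caption."""
    entries = []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            entry = json.loads(line)
            missing = {"file_name", caption_key} - set(entry)
            if missing:
                raise ValueError(f"line {line_no} is missing {sorted(missing)}")
            entries.append(entry)
    return entries


# demo with a throwaway file and hypothetical sample captions
demo_prefix = "3d icon in the style of TOK, "
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "metadata.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        for name, caption in [("img_0.jpg", "a red heart"), ("img_1.jpg", "a blue star")]:
            json.dump({"file_name": name, "prompt": demo_prefix + caption}, f)
            f.write("\n")
    entries = read_metadata(demo_path)

print(len(entries), all(e["prompt"].startswith(demo_prefix) for e in entries))  # 2 True
```

On your own data you would call `read_metadata(f"{local_dir}metadata.jsonl")` and run the same prefix check with `caption_prefix`.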
Free some memory:
import gc
# delete the BLIP2 pipelines and clear up some memory
del blip_processor, blip_model
gc.collect()
torch.cuda.empty_cache()
Initialize `accelerate`:
!accelerate config default
Pass your write access token so that we can push the trained checkpoints to the Hugging Face Hub:
from huggingface_hub import notebook_login
notebook_login()
Diffusers 🧨 Training loop
How to choose your hyperparameters? Check out this blog post, where we explore and compare different hyperparameters and configurations for different use cases, depending on your data and subject.
Make sure to add `--push_to_hub` so that the trained checkpoints are automatically pushed to the Hugging Face Hub and don't get lost.
Some parameters that can help us with compute when doing DreamBooth with LoRA on a heavy pipeline like Stable Diffusion XL:
- Gradient accumulation (`--gradient_accumulation_steps`)
- 8-bit Adam (`--use_8bit_adam`): optional when using `--optimizer='AdamW'`; with `--optimizer='Prodigy'` this will be ignored
- Mixed-precision training (`--mixed_precision="bf16"`)
To allow for custom captions we need to install the `datasets` library:
- You can set `--caption_column` to specify the name of the caption column in your dataset.
- In this example we used `"prompt"` to save our captions in the metadata file; change this according to your needs.
Otherwise, if you don't need custom captions, every image is trained with the same `--instance_prompt`. In that case, specify `--instance_data_dir` instead of `--dataset_name`.
# makes sure we install datasets from main
!pip install git+https://github.com/huggingface/datasets.git -q
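The two captioning modes described above boil down to a simple rule for which prompt each image is trained with. A minimal sketch, assuming all-or-nothing behaviour (either every image has its own caption from the caption column, or every image uses `instance_prompt`); `resolve_prompts` is a hypothetical helper, not the training script's implementation.

```python
def resolve_prompts(file_names, instance_prompt, captions=None):
    """Map each image file name to the prompt it would be trained with.

    captions: optional dict of file_name -> caption (e.g. read from a caption column).
    """
    if captions is None:
        # no caption column: every image shares the instance prompt
        return {name: instance_prompt for name in file_names}
    # caption column given: every image needs its own caption
    missing = [name for name in file_names if name not in captions]
    if missing:
        raise KeyError(f"no caption for: {missing}")
    return {name: captions[name] for name in file_names}


print(resolve_prompts(["a.jpg", "b.jpg"], "3d icon in the style of TOK"))
print(resolve_prompts(["a.jpg"], "3d icon in the style of TOK",
                      {"a.jpg": "3d icon in the style of TOK, a red heart"}))
```

Failing loudly on a missing caption, rather than silently falling back, makes a half-captioned dataset easy to spot before a long training run.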
This name will be used to save your model, so pick an informative name based on your chosen concept💡
!pip install python-slugify
from slugify import slugify
model_name = "3d icon SDXL LoRA" # @param
output_dir = slugify(model_name)
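`slugify` turns the display name into a folder-safe string, and the `repo_id` cell after training assumes the Hub repo has that same name (the run below saved its weights to `3d-icon-sdxl-lora/`). The stdlib approximation below handles simple names only; `simple_slugify` is a hypothetical stand-in and does not replicate every rule of `python-slugify`.

```python
import re
import unicodedata


def simple_slugify(text):
    """Rough approximation of python-slugify's default output for simple names."""
    # drop accents by decomposing characters and discarding non-ASCII marks
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = text.lower()
    # collapse every run of non-alphanumeric characters into a single hyphen
    text = re.sub(r"[^a-z0-9]+", "-", text)
    return text.strip("-")


print(simple_slugify("3d icon SDXL LoRA"))  # 3d-icon-sdxl-lora
```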
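Instance & Validation Prompt
- `instance_prompt`: the prompt that identifies your concept; it is used for every training image when no custom captions are provided
- `validation_prompt`: the prompt used to generate images during training so you can follow progress (logged to wandb in this run)
- `num_validation_images` (4 by default) and `validation_epochs` (50 by default) control the number of images generated with the validation prompt and the number of epochs between each DreamBooth validation.
instance_prompt = "3d icon in the style of TOK" # @param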
validation_prompt = "a TOK icon of an astronaut riding a horse, in the style of TOK" # @param
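Since `validation_epochs` is counted in epochs rather than steps, it helps to translate it into optimizer steps. A sketch, assuming validation runs at the end of every epoch whose zero-based index is a multiple of `validation_epochs`, and using the 8 updates per epoch and 125 epochs of this notebook's run; `validation_steps` is a hypothetical helper. The result matches the steps at which the training log reports "Running validation..." (8, 408 and 808).

```python
def validation_steps(num_epochs, updates_per_epoch, validation_epochs=50):
    """Global step at which each validation round would run (end of epoch)."""
    return [
        (epoch + 1) * updates_per_epoch
        for epoch in range(num_epochs)
        if epoch % validation_epochs == 0
    ]


print(validation_steps(125, 8))  # [8, 408, 808]
```

With `num_validation_images=4`, that is 3 rounds of 4 images during training.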
Set your LoRA rank
The rank of your LoRA is linked to its expressiveness. The bigger the rank, the closer we are to regular DreamBooth, and in theory the more expressive power we have (and the heavier the weights).
For a very simple concept that you have a good, high-quality image set for (e.g. a pet, a generic object), a rank as low as 4 can be enough to get great results. We recommend going between 8 and 64, depending on your concept and how much of a priority it is for you to keep the LoRA small.
rank = 8 # @param
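To see why rank drives file size: LoRA keeps a weight matrix `W` of shape `d_out x d_in` frozen and learns two factors, `A` (`rank x d_in`) and `B` (`d_out x rank`), so each adapted layer adds `rank * (d_in + d_out)` parameters. The sketch counts them for a hypothetical 640 x 640 projection, an illustrative size rather than one taken from SDXL's config.

```python
def lora_params(d_in, d_out, rank):
    # A is rank x d_in, B is d_out x rank
    return rank * (d_in + d_out)


def full_params(d_in, d_out):
    # the frozen base weight W is d_out x d_in
    return d_in * d_out


for r in (4, 8, 32, 64):
    print(f"rank {r:>2}: {lora_params(640, 640, r):>6} LoRA params vs {full_params(640, 640)} in W")
```

The count grows linearly with rank, so doubling the rank roughly doubles the size of the saved LoRA weights.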
!accelerate launch train_dreambooth_lora_sdxl_advanced.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="./3d_icon" \
  --instance_prompt="$instance_prompt" \
  --validation_prompt="$validation_prompt" \
  --output_dir="$output_dir" \
  --caption_column="prompt" \
  --mixed_precision="bf16" \
  --resolution=1024 \
  --train_batch_size=3 \
  --repeats=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --adam_beta2=0.99 \
  --optimizer="prodigy" \
  --train_text_encoder_ti \
  --train_text_encoder_ti_frac=0.5 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --rank="$rank" \
  --max_train_steps=1000 \
  --checkpointing_steps=2000 \
  --seed="0" \
  --push_to_hub
validation prompt: a <s0><s1> icon of an astronaut riding a horse, in the style of <s0><s1>
***** Running training *****
Num examples = 23
Num batches each epoch = 8
Num Epochs = 125
Instantaneous batch size per device = 3
Total train batch size (w. parallel, distributed & accumulation) = 3
Gradient Accumulation steps = 1
Total optimization steps = 1000
PIVOT HALFWAY 62
Steps: 100%|█████████████| 1000/1000 [43:15<00:00, 2.27s/it, loss=0.0651, lr=1]
Model weights saved in 3d-icon-sdxl-lora/pytorch_lora_weights.safetensors
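With `--train_text_encoder_ti_frac=0.5`, the new token embeddings are only optimized for the first part of training, after which training continues on the LoRA weights alone; the log line `PIVOT HALFWAY 62` marks that switch. A sketch of the arithmetic, assuming the script truncates `num_epochs * frac` to an integer (consistent with 125 epochs giving 62, and with the pivot appearing at step 62 * 8 = 496 in the log); `pivot_epoch` and `training_phases` are hypothetical helpers.

```python
def pivot_epoch(num_epochs, ti_frac):
    # assumption: the fraction of epochs is truncated, not rounded
    return int(num_epochs * ti_frac)


def training_phases(num_epochs, ti_frac):
    """Label each zero-based epoch with what is being optimized."""
    pivot = pivot_epoch(num_epochs, ti_frac)
    return ["embeddings + LoRA" if epoch < pivot else "LoRA only" for epoch in range(num_epochs)]


phases = training_phases(125, 0.5)
print(pivot_epoch(125, 0.5), phases.count("embeddings + LoRA"), phases.count("LoRA only"))
# 62 62 63
```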
from huggingface_hub import whoami
from IPython.display import display, Markdown
# whoami() uses the token saved by notebook_login()
username = whoami()["name"]
repo_id = f"{username}/{output_dir}"
link_to_model = f"https://huggingface.co/{repo_id}"
display(Markdown("### Your model has finished training.\nAccess it here: {}".format(link_to_model)))
Access it here: https://huggingface.co/LinoyTsaban/3d-icon-sdxl-lora
import torch
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
# load the trained LoRA weights from the Hub repo
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embedding_path = hf_hub_download(repo_id=repo_id, filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
instance_token = "<s0><s1>"  # the inserted tokens that replaced the TOK placeholder during training
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image
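The `scale` entry in `cross_attention_kwargs` controls how strongly the LoRA update is applied at inference: conceptually the adapted weight is `W + scale * (B @ A)`, so 0.0 recovers the base model and 1.0 applies the full trained update. The pure-Python sketch below shows only that arithmetic on tiny matrices; LoRA implementations such as PEFT typically also multiply by `lora_alpha / rank`, which is omitted here, and `apply_lora` is a hypothetical helper.

```python
def apply_lora(W, A, B, scale=1.0):
    """Return W + scale * (B @ A) for nested-list matrices.

    W: d_out x d_in, A: rank x d_in, B: d_out x rank.
    """
    d_out, d_in, rank = len(W), len(W[0]), len(A)
    return [
        [W[i][j] + scale * sum(B[i][k] * A[k][j] for k in range(rank)) for j in range(d_in)]
        for i in range(d_out)
    ]


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight
A = [[1.0, 1.0]]              # rank-1 down-projection
B = [[1.0], [2.0]]            # rank-1 up-projection

print(apply_lora(W, A, B, scale=0.0))  # [[1.0, 0.0], [0.0, 1.0]]
print(apply_lora(W, A, B, scale=0.5))  # [[1.5, 0.5], [1.0, 2.0]]
```

Lowering `scale` below 1.0 is a cheap way to soften a style LoRA that overpowers the prompt, without retraining.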