105 changes: 105 additions & 0 deletions docs/source/en/api/pipelines/pixart.md
@@ -35,6 +35,111 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)

</Tip>

## Inference with under 8GB GPU VRAM

Run the [`PixArtAlphaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.

First, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library:

```bash
pip install -U bitsandbytes
```

Then load the text encoder in 8-bit:

```python
from transformers import T5EncoderModel
from diffusers import PixArtAlphaPipeline

text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",
)
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=text_encoder,
transformer=None,
device_map="auto"
)
```
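
If you want a quick sanity check of how much memory the quantized text encoder takes on its own, you can print its footprint. This is optional and purely illustrative; the exact number depends on your `transformers` and `bitsandbytes` versions:

```python
# Rough size of the 8-bit text encoder in GB (uses transformers' PreTrainedModel.get_memory_footprint).
print(f"Text encoder memory: {text_encoder.get_memory_footprint() / 1024**3:.2f} GB")
```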

Define a `flush()` utility function to free up GPU VRAM after components are deleted:

```python
import gc

import torch


def flush():
    gc.collect()
    torch.cuda.empty_cache()
```

Now, use the `pipe` to encode a prompt, then delete the text encoder and pipeline and call `flush()`:

```python
with torch.no_grad():
    prompt = "cute cat"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

del text_encoder
del pipe
flush()
```
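
If you want to confirm the embeddings survived the cleanup, a quick inspection (purely for illustration) looks like this:

```python
# The embeddings are regular tensors that remain on the GPU after the encoder and pipeline are deleted.
print(prompt_embeds.shape, prompt_embeds.dtype, prompt_embeds.device)
```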

Then compute the latents with the prompt embeddings as inputs:

```python
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=None,
torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
negative_prompt=None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
num_images_per_prompt=1,
output_type="latent",
).images

del pipe.transformer
flush()
```

<Tip>

Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.

</Tip>
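
Before decoding, you can sanity-check what the pipeline returned; with `output_type="latent"` the result is a batch of latent tensors rather than images (illustrative only):

```python
# `latents` is a torch.Tensor, typically (batch, latent_channels, height // 8, width // 8)
# for the SD-style VAE used by PixArt-Alpha.
print(latents.shape, latents.dtype)
```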

Once the latents are computed, pass them to the VAE to decode into an actual image:

```python
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save("cat.png")
```

By deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtAlphaPipeline`] with under 8GB GPU VRAM.

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png)

If you want a report of your memory usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).
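
The linked script gives a fuller breakdown; for a minimal inline check, you can print PyTorch's peak allocation after each stage (a sketch, not part of the original example):

```python
# Peak GPU memory allocated so far, in GB.
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```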

<Tip warning={true}>

Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.

</Tip>

When loading the `text_encoder`, you set `load_in_8bit` to `True`. If you specify `load_in_4bit` instead, the memory requirements come down to under 7GB.
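
For reference, only the quantization flag changes; the rest of the workflow stays the same. A minimal sketch mirroring the 8-bit snippet above:

```python
# 4-bit variant of the text encoder loading shown earlier; verify the flag against your transformers version.
text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_4bit=True,
    device_map="auto",
)
```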

## PixArtAlphaPipeline

[[autodoc]] PixArtAlphaPipeline