
Commit 431df66

Revert "Other Acceleration tricks (#93)" (#94)
This reverts commit eacf501.
1 parent: eacf501 · commit: 431df66

4 files changed (+16, −128 lines)


lora_diffusion/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,3 +1,2 @@
 from .lora import *
 from .dataset import *
-from .utils import *

lora_diffusion/cli_lora_pti.py

Lines changed: 7 additions & 45 deletions
@@ -142,26 +142,21 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
-
-        if examples[0].get("mask", None) is not None:
-            batch["mask"] = torch.stack([example["mask"] for example in examples])
-
         return batch

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=train_batch_size,
         shuffle=True,
         collate_fn=collate_fn,
+        num_workers=2,
     )

     return train_dataloader


 @torch.autocast("cuda")
-def loss_step(
-    batch, unet, vae, text_encoder, scheduler, weight_dtype, t_mutliplier=1.0
-):
+def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
     latents = vae.encode(
         batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
     ).latent_dist.sample()
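The branch removed here is what let optional per-example masks travel with the batch; the face-segmentation path deleted from dataset.py below is what produced them. A minimal sketch of the reverted pattern, with the surrounding tokenizer and dataloader plumbing assumed:

import torch

def collate_fn(examples):
    # Sketch of the reverted pattern: the core keys are always present,
    # but "mask" is stacked into the batch only when the dataset
    # supplies one per example.
    batch = {
        "pixel_values": torch.stack([e["instance_images"] for e in examples]),
        "input_ids": [e["instance_prompt_ids"] for e in examples],
    }
    if examples[0].get("mask", None) is not None:
        batch["mask"] = torch.stack([e["mask"] for e in examples])
    return batch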
@@ -172,7 +167,7 @@ def loss_step(

     timesteps = torch.randint(
         0,
-        int(scheduler.config.num_train_timesteps * t_mutliplier),
+        scheduler.config.num_train_timesteps,
         (bsz,),
         device=latents.device,
     )
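The deleted `t_mutliplier` (spelling as in the source) capped the upper bound of the sampled timestep, so tuning only saw the lower portion of the noise schedule; the revert restores sampling over the full range. A small illustration of the capped sampling, assuming the usual 1000-step schedule:

import torch

num_train_timesteps = 1000  # assumed scheduler.config.num_train_timesteps
t_mutliplier = 0.8          # value passed by perform_tuning before the revert
bsz = 4

# With the multiplier, timesteps come from [0, 800): the noisiest 20%
# of the schedule is never sampled during tuning.
timesteps = torch.randint(0, int(num_train_timesteps * t_mutliplier), (bsz,))
assert int(timesteps.max()) < 800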
@@ -191,31 +186,6 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

-    if batch.get("mask", None) is not None:
-
-        mask = (
-            batch["mask"]
-            .to(model_pred.device)
-            .reshape(
-                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
-            )
-        )
-        # resize to match model_pred
-        mask = (
-            F.interpolate(
-                mask.float(),
-                size=model_pred.shape[-2:],
-                mode="nearest",
-            )
-            + 0.1
-        )
-
-        mask = mask / mask.mean()
-
-        model_pred = model_pred * mask
-
-        target = target * mask
-
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
     return loss
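The block deleted above weighted the diffusion loss spatially: the mask arrives in pixel space (8x the latent resolution, hence the *8 reshape), is downsampled to the prediction's size, floored at 0.1 so the background still contributes gradient, and normalized by its mean to keep the loss scale stable. Pulled out as a self-contained sketch of the reverted computation (the `masked_mse` helper name is hypothetical):

import torch
import torch.nn.functional as F

def masked_mse(model_pred, target, mask):
    # Mask is a per-pixel weight map in pixel space; bring it to the
    # latent-resolution prediction via the 8x reshape + interpolation.
    mask = mask.to(model_pred.device).reshape(
        model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
    )
    mask = F.interpolate(mask.float(), size=model_pred.shape[-2:], mode="nearest")
    # +0.1 keeps a small gradient signal outside the masked region;
    # dividing by the mean keeps the overall loss magnitude comparable
    # to the unmasked case.
    mask = mask + 0.1
    mask = mask / mask.mean()
    return F.mse_loss(
        model_pred.float() * mask, target.float() * mask, reduction="mean"
    )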

@@ -303,15 +273,7 @@ def perform_tuning(
         for batch in dataloader:
             optimizer.zero_grad()

-            loss = loss_step(
-                batch,
-                unet,
-                vae,
-                text_encoder,
-                scheduler,
-                weight_dtype,
-                t_mutliplier=0.8,
-            )
+            loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype)
             loss.backward()
             torch.nn.utils.clip_grad_norm_(
                 itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
@@ -360,7 +322,7 @@ def train(
     class_data_dir: Optional[str] = None,
     stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
-    use_template: Literal[None, "object", "style"] = None,
+    use_template: Optional[str] = Literal[None, "object", "style"],
     placeholder_tokens: str = "<s>",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: str = "dog",
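A typing note on this hunk: the `-` line is the type-correct form, with `Literal[...]` as the annotation and `None` as the default, while the `+` line restored by the revert assigns the `Literal` object itself as the default value, which type checkers will flag:

from typing import Literal, Optional

# Type-correct form (the removed line):
use_template: Literal[None, "object", "style"] = None

# Form restored by the revert; it runs, but the default is the Literal
# type object rather than one of its values:
# use_template: Optional[str] = Literal[None, "object", "style"]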
@@ -370,6 +332,7 @@ def train(
     num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
+    center_crop: bool = False,
     color_jitter: bool = True,
     train_batch_size: int = 1,
     sample_batch_size: int = 1,
@@ -387,7 +350,6 @@ def train(
     learning_rate_ti: float = 5e-4,
     continue_inversion: bool = True,
     continue_inversion_lr: Optional[float] = None,
-    use_face_segmentation_condition: bool = False,
     scale_lr: bool = False,
     lr_scheduler: str = "constant",
     lr_warmup_steps: int = 100,
@@ -451,8 +413,8 @@ def train(
         class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
+        center_crop=center_crop,
         color_jitter=color_jitter,
-        use_face_segmentation_condition=use_face_segmentation_condition,
     )

     train_dataloader = text2img_dataloader(

lora_diffusion/dataset.py

Lines changed: 9 additions & 70 deletions
@@ -1,12 +1,11 @@
 from torch.utils.data import Dataset

 from typing import List, Tuple, Dict, Union, Optional
-from PIL import Image, ImageFilter
+from PIL import Image
 from torchvision import transforms
 from pathlib import Path
-import cv2
+
 import random
-import numpy as np

 OBJECT_TEMPLATE = [
     "a photo of a {}",
@@ -91,12 +90,12 @@ def __init__(
         class_prompt=None,
         size=512,
         h_flip=True,
+        center_crop=False,
         color_jitter=False,
         resize=True,
-        use_face_segmentation_condition=False,
-        blur_amount: int = 70,
     ):
         self.size = size
+        self.center_crop = center_crop
         self.tokenizer = tokenizer
         self.resize = resize

@@ -122,32 +121,25 @@ def __init__(
             self.class_prompt = class_prompt
         else:
             self.class_data_root = None
-        self.h_flip = h_flip
+
         self.image_transforms = transforms.Compose(
             [
                 transforms.Resize(
                     size, interpolation=transforms.InterpolationMode.BILINEAR
                 )
                 if resize
                 else transforms.Lambda(lambda x: x),
-                transforms.ColorJitter(0.1, 0.1)
+                transforms.ColorJitter(0.2, 0.1)
                 if color_jitter
                 else transforms.Lambda(lambda x: x),
+                transforms.RandomHorizontalFlip()
+                if h_flip
+                else transforms.Lambda(lambda x: x),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
         )

-        self.use_face_segmentation_condition = use_face_segmentation_condition
-        if self.use_face_segmentation_condition:
-            import mediapipe as mp
-
-            mp_face_detection = mp.solutions.face_detection
-            self.face_detection = mp_face_detection.FaceDetection(
-                model_selection=1, min_detection_confidence=0.5
-            )
-            self.blur_amount = blur_amount
-
     def __len__(self):
         return self._length
@@ -171,59 +163,6 @@ def __getitem__(self, index):
             for token, value in self.token_map.items():
                 text = text.replace(token, value)

-        if self.use_face_segmentation_condition:
-            image = cv2.imread(
-                str(self.instance_images_path[index % self.num_instance_images])
-            )
-            results = self.face_detection.process(
-                cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            )
-            black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
-
-            if results.detections:
-
-                for detection in results.detections:
-
-                    x_min = int(
-                        detection.location_data.relative_bounding_box.xmin
-                        * image.shape[1]
-                    )
-                    y_min = int(
-                        detection.location_data.relative_bounding_box.ymin
-                        * image.shape[0]
-                    )
-                    width = int(
-                        detection.location_data.relative_bounding_box.width
-                        * image.shape[1]
-                    )
-                    height = int(
-                        detection.location_data.relative_bounding_box.height
-                        * image.shape[0]
-                    )
-
-                    # draw the colored rectangle
-                    black_image[y_min : y_min + height, x_min : x_min + width] = 255
-
-            # blur the image
-            black_image = Image.fromarray(black_image, mode="L").filter(
-                ImageFilter.GaussianBlur(radius=self.blur_amount)
-            )
-            # to tensor
-            black_image = transforms.ToTensor()(black_image)
-            # resize as the instance image
-            black_image = transforms.Resize(
-                self.size, interpolation=transforms.InterpolationMode.BILINEAR
-            )(black_image)
-
-            example["mask"] = black_image
-
-        if self.h_flip and random.random() > 0.5:
-            hflip = transforms.RandomHorizontalFlip(p=1)
-
-            example["instance_images"] = hflip(example["instance_images"])
-            if self.use_face_segmentation_condition:
-                example["mask"] = hflip(example["mask"])
-
         example["instance_prompt_ids"] = self.tokenizer(
             text,
             padding="do_not_pad",
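For reference, the deleted conditioning path condenses to the pipeline below: detect faces with mediapipe, paint each bounding box white on a black canvas, feather the edges with a wide Gaussian blur, and resize to match the instance image. This is a sketch assembled from the removed lines (the `face_mask` helper name is hypothetical), not a drop-in replacement:

import cv2
import numpy as np
import mediapipe as mp
from PIL import Image, ImageFilter
from torchvision import transforms

def face_mask(path: str, blur_amount: int = 70, size: int = 512):
    image = cv2.imread(path)
    h, w = image.shape[:2]
    detector = mp.solutions.face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    )
    results = detector.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    canvas = np.zeros((h, w), dtype=np.uint8)
    if results.detections:
        for det in results.detections:
            box = det.location_data.relative_bounding_box
            x, y = int(box.xmin * w), int(box.ymin * h)
            bw, bh = int(box.width * w), int(box.height * h)
            canvas[y : y + bh, x : x + bw] = 255  # white box over each face
    # Blur so the loss weighting falls off smoothly around the face.
    blurred = Image.fromarray(canvas, mode="L").filter(
        ImageFilter.GaussianBlur(radius=blur_amount)
    )
    return transforms.Resize(
        size, interpolation=transforms.InterpolationMode.BILINEAR
    )(transforms.ToTensor()(blurred))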

lora_diffusion/utils.py

Lines changed: 0 additions & 12 deletions
This file was deleted.
