```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline, DDPMScheduler


def train_lora(
    pretrained_model="CompVis/stable-diffusion-v1-4",
    data_dir="./training_data",
    output_dir="./lora_output",
    rank=4,
    learning_rate=1e-4,
    max_train_steps=1000,
    batch_size=1,
    gradient_accumulation_steps=4,
):
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model, torch_dtype=torch.float16
    ).to("cuda")
    unet = pipe.unet
    vae = pipe.vae
    text_encoder = pipe.text_encoder

    # Freeze everything except the LoRA matrices; only lora_A/lora_B
    # are handed to the optimizer below.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    unet, lora_layers = apply_lora_to_unet(unet, rank=rank)

    optimizer = torch.optim.AdamW(
        [p for layer in lora_layers for p in (layer.lora_A, layer.lora_B)],
        lr=learning_rate,
    )
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    dataset = LoRADataset(data_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    global_step = 0
    unet.train()
    while global_step < max_train_steps:
        for batch in dataloader:
            # Encode images into the VAE latent space; 0.18215 is the
            # latent scaling factor used by Stable Diffusion v1.
            with torch.no_grad():
                pixel_values = batch["pixel_values"].to("cuda", dtype=torch.float16)
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * 0.18215

            # Encode the captions into text embeddings for conditioning.
            with torch.no_grad():
                text_inputs = pipe.tokenizer(
                    batch["caption"],
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).to("cuda")
                text_embeddings = text_encoder(**text_inputs)[0]

            # Standard DDPM objective: noise the latents at a random timestep
            # and train the UNet to predict that noise.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],),
                device="cuda",
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            noise_pred = unet(
                noisy_latents, timesteps, encoder_hidden_states=text_embeddings
            ).sample
            # Compute the loss in fp32 for numerical stability, and scale it so
            # the accumulated gradient averages over the accumulation window.
            loss = F.mse_loss(noise_pred.float(), noise.float())
            (loss / gradient_accumulation_steps).backward()

            if (global_step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            global_step += 1
            if global_step % 100 == 0:
                print(f"Step {global_step}, Loss: {loss.item():.4f}")
            if global_step >= max_train_steps:
                break

    save_lora_weights(lora_layers, output_dir)
    print(f"LoRA weights saved to {output_dir}")
```
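The listing relies on helpers defined outside this section (`apply_lora_to_unet`, `LoRADataset`, `save_lora_weights`). To make the core mechanism concrete, here is a minimal, hypothetical sketch of what a LoRA injection helper could look like. The `LoRALinear` class, the `alpha` scaling factor, and the `to_q`/`to_k`/`to_v` target names are illustrative assumptions, not the implementation used above; the only contract it honors is the one the training loop needs, namely that each returned layer exposes trainable `lora_A` and `lora_B` parameters.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update: W x + (alpha/r) * B A x."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # the pretrained projection stays frozen
        self.scale = alpha / rank
        # A maps the input down to the rank-r subspace; B maps it back up.
        self.lora_A = nn.Parameter(
            torch.randn(rank, base.in_features,
                        dtype=base.weight.dtype, device=base.weight.device) * 0.01
        )
        self.lora_B = nn.Parameter(
            torch.zeros(base.out_features, rank,
                        dtype=base.weight.dtype, device=base.weight.device)
        )

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.lora_A.T @ self.lora_B.T)


def apply_lora_to_unet(unet, rank=4):
    """Swap the attention projection Linears in a diffusers UNet for LoRA-wrapped ones."""
    unet.requires_grad_(False)  # freeze every base weight
    targets = []
    for _, module in unet.named_modules():
        for child_name, child in module.named_children():
            if isinstance(child, nn.Linear) and child_name in ("to_q", "to_k", "to_v"):
                targets.append((module, child_name, child))
    lora_layers = []
    for module, child_name, child in targets:  # swap after collecting, not mid-iteration
        wrapped = LoRALinear(child, rank=rank)
        setattr(module, child_name, wrapped)
        lora_layers.append(wrapped)
    return unet, lora_layers
```

Initializing `lora_B` to zero is the standard LoRA choice: the product `B @ A` vanishes at step 0, so the wrapped layer initially behaves exactly like the pretrained one and the low-rank update grows from there.

The `LoRADataset` is likewise assumed to yield `pixel_values` (a normalized image tensor) and `caption` (a string). A minimal folder-of-images version might look like the following sketch; the 512-pixel crop and the sidecar `.txt` caption convention are illustrative choices, not requirements from the training loop:

```python
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LoRADataset(Dataset):
    """Folder of images; each image's caption comes from a sidecar .txt file (illustrative)."""

    def __init__(self, data_dir, size=512):
        self.paths = [
            os.path.join(data_dir, f)
            for f in sorted(os.listdir(data_dir))
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ]
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # map pixels to [-1, 1] as the VAE expects
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        image = Image.open(path).convert("RGB")
        with open(os.path.splitext(path)[0] + ".txt") as f:
            caption = f.read().strip()
        return {"pixel_values": self.transform(image), "caption": caption}
```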