# Guide to Adapting P4D for Different Concepts and Safe T2I Models

This guide explains how to adapt our P4D framework to work with different target concepts or safe T2I models.

## Adapting to New Target Concepts

### 1. Data Processing (`process_data.py`)

To adapt P4D for new target concepts, you'll need to modify how the framework processes and evaluates data:
- Modify the `load_dataset()` function to include your target concepts of interest and to implement appropriate filtering for your raw data.
- Modify the `Eval` class to include a custom evaluator for your target concept:
  - For object categories: consider implementing [Receler](https://github.com/jasper0314-huang/Receler)'s approach using GroundingDINO.
  - For other concepts: implement any suitable detection/evaluation method.
  - The key requirement is a binary output indicating whether the generated image contains the target concept.
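
Concretely, a custom evaluator only needs to wrap some detector that yields a confidence score and threshold it into a binary decision. A minimal sketch (the class name, the `detector` callable, and the threshold below are hypothetical, not part of the P4D codebase):

```python
class ThresholdConceptEvaluator:
    """Hypothetical evaluator: wraps any detector that maps an image to a
    confidence score in [0, 1] and thresholds it into a binary decision."""

    def __init__(self, detector, threshold=0.5):
        self.detector = detector    # callable: image -> confidence score
        self.threshold = threshold  # decision boundary for "concept present"

    def evaluate(self, image):
        score = self.detector(image)
        # Binary flag plus raw score, matching the (flag, percentage)
        # interface the `Eval` class expects
        return score >= self.threshold, score


# Toy usage with a dummy detector that always reports 0.8 confidence
evaluator = ThresholdConceptEvaluator(detector=lambda img: 0.8)
flag, score = evaluator.evaluate(image=None)
```

In practice, `detector` would be a GroundingDINO query, an NSFW classifier, or any model producing a per-image confidence for your concept.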

### 2. Running P4D (`run_p4dn.py`, `run_p4dk.py`)

The main modification needed is in the `Eval` class:

```python
import numpy as np


class Eval:
    def __init__(self):
        # Load your concept evaluator (placeholder name for your own detector)
        self.evaluator = YourConceptEvaluator()

    def __call__(self, image):
        # Input: generated image
        # Output: (contains_concept, unsafe_percentage)
        result = self.evaluator.evaluate(image)
        return result.contains_concept, result.unsafe_percentage

    def get_unsafe(self, images):
        # Input: list of images
        # Output: the image with the highest unsafe percentage
        scores = [self.evaluator.evaluate(img).unsafe_percentage for img in images]
        return images[np.argmax(scores)]
```

## Adapting to New Safe T2I Models

P4D is a white-box method that requires access to model internals. Here's how to adapt it to different models:

### For SD-Based Models

If your target model is SD-based (i.e., it uses `StableDiffusionPipeline`):
1. Use our `ModifiedStableDiffusionPipeline` from `model.p4dn(k).modified_stable_diffusion_pipeline`
2. Load your custom safety components (e.g., a fine-tuned checkpoint or standalone safety modules written in Python)
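
The pattern in step 2 can be mocked in a few lines (the class below is a hypothetical stand-in, not the real `ModifiedStableDiffusionPipeline`, which requires `diffusers` and model weights to load): safety components are ordinary attributes you swap in after constructing the pipeline.

```python
class PipelineSketch:
    """Hypothetical stand-in for a modified SD pipeline: safety
    components are plain attributes that can be swapped in."""

    def __init__(self, unet):
        self.unet = unet
        self.safety_modules = []

    def attach_safety_module(self, module):
        # e.g. a concept-erased checkpoint or a negative-prompt module
        self.safety_modules.append(module)


pipe = PipelineSketch(unet="base-unet-weights")
pipe.attach_safety_module("erased-concept-checkpoint")
```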

### For Other Model Architectures

You'll need access to the model's source code and must make the following modifications:

#### 1. Add Required Functions

The code in `models/` serves as an example:
```python
def _new_encode_prompt(self, prompt):
    # Encode the prompt into text-encoder hidden states.
    # This should be the first step of your T2I model's forward pass,
    # implemented as a standalone function.
    pass

def _get_text_embeddings_with_embeddings(self, dummy_embeddings):
    # Convert the initialized adversarial prompt embeddings
    # into your model's text-embedding format.
    pass

def _expand_safety_text_embeddings(self, embeddings):
    # If your model uses safety/negative prompts,
    # concatenate them with the main text embeddings.
    pass
```
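
A toy instantiation of these three hooks (with a fake "text encoder" and made-up shapes, purely to illustrate the expected data flow; a real adapter would call the model's own tokenizer and encoder) might look like:

```python
import numpy as np


class ToyT2IAdapter:
    """Hypothetical adapter showing the three required hooks.
    All shapes and the random 'encoder' are illustrative only."""

    def __init__(self, embed_dim=8, seq_len=4):
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        # Stand-in for learned safety/negative-prompt embeddings
        self.safety_embeddings = np.zeros((seq_len, embed_dim))

    def _new_encode_prompt(self, prompt):
        # "Encode" the prompt: a deterministic random projection keyed
        # on the prompt text (a real model runs its text encoder here)
        rng = np.random.default_rng(sum(map(ord, prompt)))
        return rng.normal(size=(self.seq_len, self.embed_dim))

    def _get_text_embeddings_with_embeddings(self, dummy_embeddings):
        # Convert adversarial prompt embeddings into the model's
        # text-embedding format; here the formats coincide, so pass through
        return np.asarray(dummy_embeddings, dtype=float)

    def _expand_safety_text_embeddings(self, embeddings):
        # Concatenate safety/negative embeddings with the main embeddings
        return np.concatenate([self.safety_embeddings, embeddings], axis=0)
```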

#### 2. Modify Optimization Process

In `optimize_n.py` and `optimize_k.py`, update the `optimize()` function:
```python
import torch.nn.functional as F  # at the top of the file

def optimize(self):
    # `prompt` and `dummy_embeddings` come from the surrounding
    # class/setup code (the target prompt and the initialized
    # adversarial prompt embeddings)

    # Encode the input prompt
    text_embeddings = self._new_encode_prompt(prompt)

    # Encode the adversarial embeddings
    adv_embeddings = self._get_text_embeddings_with_embeddings(dummy_embeddings)
    adv_embeddings = self._expand_safety_text_embeddings(adv_embeddings)

    # Forward pass through the unconstrained T2I model, up to noise prediction
    noise_pred_unconstr = self.forward_pass(text_embeddings)

    # Forward pass through the safe T2I model, up to noise prediction
    noise_pred_safe = self.forward_pass(adv_embeddings)

    # MSE loss between the two noise predictions
    loss = F.mse_loss(noise_pred_safe, noise_pred_unconstr)

    # Backpropagate into the adversarial embeddings
    loss.backward()
```
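
To see the objective in isolation, here is a self-contained NumPy toy (the linear "noise predictors", weights, and shapes are made up for illustration) that minimizes the same MSE between the safe and unconstrained predictions by gradient descent on the adversarial embedding:

```python
import numpy as np

# Toy stand-ins for the two noise predictors: well-conditioned
# diagonal linear maps (purely illustrative)
W_unconstr = np.diag([1.0, 0.8, 1.2, 0.9])
W_safe = np.diag([0.9, 1.1, 0.7, 1.0])

rng = np.random.default_rng(0)
text_embeddings = rng.normal(size=4)   # fixed input-prompt embedding
adv_embeddings = rng.normal(size=4)    # adversarial embedding being optimized

noise_pred_unconstr = W_unconstr @ text_embeddings  # fixed target prediction

lr = 0.2
for _ in range(500):
    noise_pred_safe = W_safe @ adv_embeddings
    residual = noise_pred_safe - noise_pred_unconstr
    loss = np.mean(residual ** 2)                    # MSE loss, as above
    grad = 2 * W_safe.T @ residual / residual.size   # gradient wrt adv_embeddings
    adv_embeddings -= lr * grad                      # backward + update step

final_loss = np.mean((W_safe @ adv_embeddings - noise_pred_unconstr) ** 2)
```

The loss is driven close to zero: the adversarial embedding makes the "safe" predictor reproduce the unconstrained prediction, which is exactly the failure mode P4D searches for.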

### Important Implementation Notes

1. **Memory Management**: For computationally intensive models:
   - Distribute components across multiple GPUs
   - Default setup: the safe T2I and unconstrained T2I models sit on separate GPUs
   - Watch for memory leaks when using multiple devices

2. **Optimization Loop**:
   - Every 50 iterations, generate images using the current adversarial prompts
   - Compare them with the target images (generated by the unconstrained T2I model)
   - Update the best adversarial prompt based on image similarity
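
That bookkeeping can be sketched as follows (the `generate` and `similarity` callables and the 50-step interval are placeholders for the actual image generation and image-similarity metric):

```python
def track_best_prompt(prompts_per_iter, generate, similarity, target, eval_every=50):
    """Keep the adversarial prompt whose generated image is most similar
    to the target image from the unconstrained T2I model."""
    best_prompt, best_score = None, float("-inf")
    for it, prompt in enumerate(prompts_per_iter, start=1):
        if it % eval_every != 0:
            continue  # only evaluate periodically to save compute
        image = generate(prompt)           # image from current adversarial prompt
        score = similarity(image, target)  # compare with unconstrained output
        if score > best_score:
            best_prompt, best_score = prompt, score
    return best_prompt, best_score


# Toy usage: "images" are integers; similarity is negative absolute difference
prompts = [f"p{i}" for i in range(1, 201)]
best, score = track_best_prompt(
    prompts,
    generate=lambda p: int(p[1:]),  # toy "image" = prompt index
    similarity=lambda img, tgt: -abs(img - tgt),
    target=120,
)
```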

3. **API Limitations**:
   - P4D cannot be used with API-only models (e.g., DALL·E 3)
   - Source code access is required for a proper implementation

Remember to thoroughly test your modifications and monitor for unexpected behavior, especially when managing memory across multiple devices.