From 37058a97e3b6c81137bd7b57db969a57a1715d32 Mon Sep 17 00:00:00 2001
From: Ryan Wolf
Date: Wed, 16 Oct 2024 15:41:49 -0700
Subject: [PATCH] Change NSFW Model (#307)

* Change download for NSFW model

Signed-off-by: Ryan Wolf

* Fix model init

Signed-off-by: Ryan Wolf

* Fix embedding size

Signed-off-by: Ryan Wolf

---------

Signed-off-by: Ryan Wolf
---
 nemo_curator/image/classifiers/nsfw.py | 64 ++++++++++++++------------
 1 file changed, 35 insertions(+), 29 deletions(-)

diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py
index e66abcaef..16aecf818 100644
--- a/nemo_curator/image/classifiers/nsfw.py
+++ b/nemo_curator/image/classifiers/nsfw.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import zipfile
 from typing import Optional
 
 import requests
@@ -23,33 +24,35 @@
 
 
 # MLP code taken from LAION's CLIP-based-NSFW-Detector
-# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/h14_nsfw_model.py
-class H14_NSFW_Detector(nn.Module):
-    def __init__(self, input_size=1024):
+# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/issues/7
+class Normalization(nn.Module):
+    def __init__(self, shape):
         super().__init__()
-        self.input_size = input_size
-        self.layers = nn.Sequential(
-            nn.Linear(self.input_size, 1024),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-            nn.Linear(1024, 2048),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-            nn.Linear(2048, 1024),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-            nn.Linear(1024, 256),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-            nn.Linear(256, 128),
-            nn.ReLU(),
-            nn.Dropout(0.2),
-            nn.Linear(128, 16),
-            nn.Linear(16, 1),
-        )
+        self.register_buffer("mean", torch.zeros(shape))
+        self.register_buffer("variance", torch.ones(shape))
+
+    def forward(self, x):
+        return (x - self.mean) / self.variance.sqrt()
+
+
+class NSFWModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.norm = Normalization([768])
+        self.linear_1 = nn.Linear(768, 64)
+        self.linear_2 = nn.Linear(64, 512)
+        self.linear_3 = nn.Linear(512, 256)
+        self.linear_4 = nn.Linear(256, 1)
+        self.act = nn.ReLU()
+        self.act_out = nn.Sigmoid()
 
     def forward(self, x):
-        return self.layers(x)
+        x = self.norm(x)
+        x = self.act(self.linear_1(x))
+        x = self.act(self.linear_2(x))
+        x = self.act(self.linear_3(x))
+        x = self.act_out(self.linear_4(x))
+        return x
 
 
 class NsfwClassifier(ImageClassifier):
@@ -66,7 +69,7 @@ def __init__(
             pred_column=pred_column,
             pred_type=float,
             batch_size=batch_size,
-            embedding_size=1024,
+            embedding_size=768,
         )
 
         if model_path is None:
@@ -76,21 +79,24 @@ def __init__(
 
     @staticmethod
    def _get_default_model():
-        weights_name = "h14_nsfw.pth"
+        weights_name = "clip_autokeras_binary_nsfw.pth"
         model_path = os.path.join(NEMO_CURATOR_HOME, weights_name)
         os.makedirs(NEMO_CURATOR_HOME, exist_ok=True)
 
         if not os.path.exists(model_path):
-            url = f"https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/{weights_name}?raw=true"
+            url = "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/files/10250461/clip_autokeras_binary_nsfw.zip"
             r = requests.get(url)
 
-            with open(model_path, "wb") as f:
+            raw_zip_path = os.path.join(NEMO_CURATOR_HOME, "nsfw.zip")
+            with open(raw_zip_path, "wb") as f:
                 f.write(r.content)
+            with zipfile.ZipFile(raw_zip_path, "r") as f:
+                f.extractall(NEMO_CURATOR_HOME)
 
         return model_path
 
     def load_model(self, device):
-        model = H14_NSFW_Detector(input_size=self.embedding_size).to(device)
+        model = NSFWModel().to(device)
         weights = torch.load(self.model_path, map_location=torch.device("cpu"))
         model.load_state_dict(weights)
         model.eval()
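
Note (not part of the patch): a minimal sketch of how the new NSFWModel could be smoke-tested. The random tensors below are stand-ins for 768-dimensional CLIP image embeddings, so the scores are meaningless until the downloaded checkpoint is loaded (e.g. via NsfwClassifier.load_model); the snippet only checks shapes and the sigmoid output range.

    import torch
    from nemo_curator.image.classifiers.nsfw import NSFWModel

    # Untrained model, illustrative only; real weights come from the
    # clip_autokeras_binary_nsfw checkpoint downloaded by NsfwClassifier.
    model = NSFWModel().eval()

    dummy_embeddings = torch.randn(4, 768)  # batch of 4 stand-in CLIP embeddings
    with torch.no_grad():
        scores = model(dummy_embeddings)    # sigmoid outputs, shape (4, 1), in [0, 1]
    print(scores.squeeze(1))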