
Model Misclassification (All Images Classified as NSFW) #59

mehranfvs opened this issue Aug 18, 2024 · 0 comments
Description:

I've been using the "MobileVLM_V2-1.7B" model for a task where I need to classify images as either "NSFW" or "SFW".
However, the model consistently classifies every image as NSFW, regardless of its actual content.
This behavior persists across various prompt modifications and different image inputs.

To reproduce:
prompt = "Is this picture sfw or nsfw?\nAnswer the question using a single word of nsfw or sfw"
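
For context, this question is not sent to the model verbatim; it is first wrapped by the "v1" conversation template. A minimal sketch (mirroring the conv_templates calls made in infer() below) to print the exact string the tokenizer receives:

  # Minimal sketch: how the "v1" template wraps the raw question before
  # tokenization; these are the same calls made in infer() below.
  from mobilevlm.conversation import conv_templates
  from mobilevlm.constants import DEFAULT_IMAGE_TOKEN

  prompt = "Is this picture sfw or nsfw?\nAnswer the question using a single word of nsfw or sfw"
  conv = conv_templates["v1"].copy()
  conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
  conv.append_message(conv.roles[1], None)
  print(conv.get_prompt())  # the exact string passed to tokenizer_image_token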

Expected Behavior:
The model should classify images accurately as SFW or NSFW based on the content.

Actual Behavior:
All images are classified as NSFW, even those that are clearly SFW.

I made some changes to the script, including encapsulating the functionality in a class and ...
However, the overall workflow and logic of the code remain consistent with the original version:
https://github.com/Meituan-AutoML/MobileVLM/blob/main/scripts/inference.py

My code:

  import sys
  import os
  import torch
  import argparse
  from PIL import Image
  from pathlib import Path
  from time import time
  import transformers
  import warnings
  
  transformers.logging.set_verbosity_error()
  warnings.filterwarnings('ignore')
  
  sys.path.append(str(Path(__file__).parent.parent.resolve()))
  from mobilevlm.model.mobilevlm import load_pretrained_model
  from mobilevlm.conversation import conv_templates, SeparatorStyle
  from mobilevlm.utils import disable_torch_init, process_images, tokenizer_image_token, KeywordsStoppingCriteria
  from mobilevlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
  
  
  
  class InferenceOnce:

      def load_model(self, model_path, load_8bit=False, load_4bit=False):
          # disable_torch_init() must run before loading: it skips redundant
          # default weight initialization to speed up model creation
          disable_torch_init()
          self.model_name = model_path.split('/')[-1]
          self.tokenizer, self.model, self.image_processor, self.context_len = \
              load_pretrained_model(model_path, load_8bit, load_4bit)
          
      def load_image(self, image_file):
          # NOTE: Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its
          # replacement. Also note that a fixed 512x512 resize distorts the
          # aspect ratio of non-square images.
          images = [Image.open(image_file).convert("RGB").resize((512, 512), Image.LANCZOS)]
          self.images_tensor = process_images(
              images, self.image_processor, self.model.config
          ).to(self.model.device, dtype=torch.float16)
  
      def infer(self, args):
          conv = conv_templates[args.conv_mode].copy()
          conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + args.prompt)
          conv.append_message(conv.roles[1], None)

          prompt = conv.get_prompt()
          stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

          # Input: tokenize the prompt, splicing in the image placeholder token
          self.input_ids = tokenizer_image_token(
              prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
          ).unsqueeze(0).cuda()

          self.stopping_criteria = KeywordsStoppingCriteria(
              [stop_str], self.tokenizer, self.input_ids
          )
          
          # Inference
          with torch.inference_mode():
              output_ids = self.model.generate(
                  self.input_ids,
                  images=self.images_tensor,
                  do_sample=args.temperature > 0,
                  temperature=args.temperature,
                  top_p=args.top_p,
                  num_beams=args.num_beams,
                  max_new_tokens=args.max_new_tokens,
                  use_cache=True,
                  stopping_criteria=[self.stopping_criteria],
              )
  
          # Result-Decode
          input_token_len = self.input_ids.shape[1]
          n_diff_input_output = (self.input_ids != output_ids[:, :input_token_len]).sum().item()
          
          if n_diff_input_output > 0:
              print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
          
          outputs = self.tokenizer.batch_decode(
              output_ids[:, input_token_len:], skip_special_tokens=True
          )[0].strip()
          if outputs.endswith(stop_str):
              outputs = outputs[: -len(stop_str)]
          # print(f"🚀 {self.model_name}: {outputs.strip()}\n")
          return outputs
  
  if __name__ == "__main__":
      model_path = "/MobileVLM_V2-1.7B"
      image_folder = "/samples"
      prompt = "Is this picture sfw or nsfw?\nAnswer the question using a single word of nsfw or sfw"

      # Generation settings match the original script's defaults (greedy decoding)
      args = argparse.Namespace(
          prompt=prompt,
          conv_mode="v1",
          temperature=0,
          top_p=None,
          num_beams=1,
          max_new_tokens=512,
      )

      mvlm2 = InferenceOnce()
      s_l = time()
      mvlm2.load_model(model_path)
      e_l = time() - s_l
      print(f"load time : {e_l:.2f} seconds")

      for img in os.listdir(image_folder):
          img_path = os.path.join(image_folder, img)
          print(img)

          s = time()
          mvlm2.load_image(img_path)
          output = mvlm2.infer(args)
          e = time() - s

          print("\033[92m", 50 * "-")
          print(f"time : {e:.2f} seconds")
          print(f"🚀 output {img} : {output.strip()}")
          print(50 * "-", "\033[0m")

Output:

(screenshot from 2024-08-18: terminal output in which every sample image is classified as NSFW)
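
To help narrow this down, here is a hypothetical diagnostic (not part of the original script): it compares the raw next-token logits the model assigns to "sfw" vs. "nsfw" after a forward pass, using the instance attributes set by load_image() and infer() above. It assumes the model's forward pass accepts an images keyword, as its generate() call does:

  # Hypothetical diagnostic: after mvlm2.load_image(...) and mvlm2.infer(args),
  # compare the raw logits for "sfw" vs "nsfw" at the first generated position,
  # to see whether the NSFW bias is already present in the model output itself.
  with torch.inference_mode():
      out = mvlm2.model(mvlm2.input_ids, images=mvlm2.images_tensor)
      next_token_logits = out.logits[0, -1]
  for word in ("sfw", "nsfw"):
      # take the first sub-token id of each word (assumption: this is enough
      # to reveal a preference; both words may tokenize into multiple pieces)
      first_id = mvlm2.tokenizer(word, add_special_tokens=False).input_ids[0]
      print(word, next_token_logits[first_id].item())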
