diff --git a/vlmeval/dataset/image_mcq.py b/vlmeval/dataset/image_mcq.py
index 28c8a024c..a12e36456 100644
--- a/vlmeval/dataset/image_mcq.py
+++ b/vlmeval/dataset/image_mcq.py
@@ -2,6 +2,7 @@
 from .image_base import ImageBaseDataset
 from .utils import build_judge, DEBUG_MESSAGE
+from ..utils import track_progress_rich
 from ..smp import *
 import pandas as pd
 
@@ -348,6 +349,64 @@ def build_prompt(self, line):
         msgs = self.split_MMMU(msgs)
         return msgs
 
+    def evaluate(self, eval_file, **judge_kwargs):
+        from .utils.multiple_choice import (
+            mmmu_evaluation, report_acc
+        )
+        nproc = judge_kwargs.pop('nproc', 4)
+        suffix = eval_file.split('.')[-1]
+        model = judge_kwargs.get('model', 'exact_matching')
+        assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
+        name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
+        name_str = name_str_map[model] if model in name_str_map else model
+        result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}')
+        score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
+        tmp_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
+
+        if model == 'exact_matching':
+            model = None
+        elif gpt_key_set():
+            model = build_judge(**judge_kwargs)
+            if not model.working():
+                warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
+                warnings.warn(DEBUG_MESSAGE)
+                model = None
+        else:
+            warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
+            model = None
+
+        data = load(eval_file)
+        lt = len(data)
+        lines = [data.iloc[i] for i in range(lt)]
+        tups = [(model, line, self.dataset_name) for line in lines]
+        indices = [line['index'] for line in lines]
+
+        ans = {}
+        if osp.exists(tmp_file):
+            ans = load(tmp_file)
+            tups = [x for x, i in zip(tups, indices) if i not in ans]
+            indices = [i for i in indices if i not in ans]
+
+        if len(indices):
+            _ = track_progress_rich(
+                mmmu_evaluation,
+                tups,
+                nproc=nproc,
+                chunksize=nproc,
+                keys=indices,
+                save=tmp_file,
+            )
+            ans = load(tmp_file)
+        for key, value in ans.items():
+            data.loc[data['index'] == key, 'hit'] = value['hit']
+            data.loc[data['index'] == key, 'log'] = value['log']
+        dump(data, result_file)
+
+        acc = report_acc(data)
+
+        dump(acc, score_file)
+        return acc
+
 
 class MMMUProDataset(MMMUDataset):
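
Usage sketch (illustrative, not part of the patch): with this change, MMMU scoring runs through the standard per-dataset evaluate hook. The dataset name, the prediction file path, and the build_dataset factory call below are assumptions about a typical VLMEvalKit run, not taken from this diff.

    from vlmeval.dataset import build_dataset

    # Build the MMMU split and score an existing prediction file.
    # 'MMMU_DEV_VAL' and the xlsx path are placeholders for a real run's outputs.
    dataset = build_dataset('MMMU_DEV_VAL')
    acc = dataset.evaluate(
        'outputs/MyVLM/MyVLM_MMMU_DEV_VAL.xlsx',  # per-sample predictions
        model='exact_matching',                   # rule-based matching, no LLM judge
        nproc=4,
    )
    print(acc)
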
+ """ + # Pattern for numbers with commas + pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b' + # Pattern for scientific notation + pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+' + # Pattern for simple numbers without commas + pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])' + + # Extract numbers with commas + numbers_with_commas = re.findall(pattern_commas, string) + # Extract numbers in scientific notation + numbers_scientific = re.findall(pattern_scientific, string) + # Extract simple numbers without commas + numbers_simple = re.findall(pattern_simple, string) + + # Combine all extracted numbers + all_numbers = numbers_with_commas + numbers_scientific + numbers_simple + return all_numbers + + +def parse_open_response(response): + """ + Parse the prediction from the generated response. + Return a list of predicted strings or numbers. + """ + # content = content.strip("\n").strip(".").strip(" ") + def get_key_subresponses(response): + key_responses = [] + response = response.strip().strip(".").lower() + sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response) + indicators_of_keys = ['could be ', 'so ', 'is ', 'thus ', 'therefore ', 'final ', 'answer ', 'result '] + key_responses = [] + for index, resp in enumerate(sub_responses): + # if last one, accept it's an equation (the entire response can be just one sentence with equation) + if index == len(sub_responses) - 1: + indicators_of_keys.extend(['=']) + # the shortest response that may contain the answer (tail part of the response) + shortest_key_response = None + for indicator in indicators_of_keys: + if indicator in resp: + if not shortest_key_response: + shortest_key_response = resp.split(indicator)[-1].strip() + else: + if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response): + shortest_key_response = resp.split(indicator)[-1].strip() + # key_responses.append(resp.split(indicator)[1].strip()) + + if shortest_key_response: + # and it's not trivial + if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]: + key_responses.append(shortest_key_response) + if len(key_responses) == 0: # did not found any + return [response] + return key_responses + # pdb.set_trace() + key_responses = get_key_subresponses(response) + + pred_list = key_responses.copy() # keep the original string response + for resp in key_responses: + pred_list.extend(extract_numbers(resp)) + + tmp_pred_list = [] + for i in range(len(pred_list)): + tmp_pred_list.extend(normalize_str(pred_list[i])) + pred_list = tmp_pred_list + + # remove duplicates + pred_list = list(set(pred_list)) + + return pred_list + + +def check_is_number(string): + """ + Check if the given string a number. + """ + try: + float(string.replace(',', '')) + return True + except ValueError: + # check if there's comma inside + return False + + +def normalize_str(string): + """ + Normalize the str to lower case and make them float numbers if possible. + """ + # check if characters in the string + + # if number, numerize it. 
+    string = string.strip()
+
+    is_number = check_is_number(string)
+
+    if is_number:
+        string = string.replace(',', '')
+        string = float(string)
+        # keep 2 decimal places
+        string = round(string, 2)
+        return [string]
+    else:  # it's likely to be a string
+        # lower it
+        string = string.lower()
+        if len(string) == 1:
+            return [" " + string, string + " "]  # avoid trivial matches
+        return [string]
+
+
+def mmmu_evaluation(model, line, dataset_name):
+    if 'question_type' in line and line['question_type'] == 'open':
+        hit = 0
+        match_log = 'Failed to match'
+        if isinstance(line['answer'], list):
+            # use float to avoid trivial matches
+            norm_answers = []
+            for answer in line['answer']:
+                norm_answers.extend(normalize_str(answer))
+        else:
+            norm_answers = normalize_str(line['answer'])
+        parsed_pred = parse_open_response(line['prediction'])
+        for pred in parsed_pred:  # pred is already normalized in the parse phase
+            if isinstance(pred, str):  # if it's a string, check whether an answer appears in it
+                for norm_ans in norm_answers:
+                    # only check whether the string answer is contained in the string pred
+                    if isinstance(norm_ans, str) and norm_ans in pred:
+                        if not hit:
+                            hit = 1
+                            match_log = 'answer in pred'
+                        break
+            else:  # it's a float number
+                if pred in norm_answers:
+                    if not hit:
+                        hit = 1
+                        match_log = 'pred is float, hit the answer'
+                    break
+        return dict(hit=hit, log=f'Match Log: {match_log}. ')
+    else:
+        res = extract_answer_from_item(model, line, dataset_name=dataset_name)
+        opt, match_log = res['opt'], res['log']
+        if opt == line['answer']:
+            return dict(hit=1, log=f'Match Log: {match_log}. ')
+        else:
+            return dict(hit=0, log=f'Match Log: {match_log}. ')
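
Worked example (illustrative, not part of the patch): the open-ended branch of mmmu_evaluation reduces to the membership check below. The prediction text is made up; it only demonstrates how parse_open_response and normalize_str turn numeric answers into rounded floats before matching.

    from vlmeval.dataset.utils.multiple_choice import (
        normalize_str, parse_open_response
    )

    prediction = 'The area of the triangle is 1,200.5 square units. Therefore the answer is 1200.5'
    answer = '1200.5'

    preds = parse_open_response(prediction)  # -> [1200.5]
    golds = normalize_str(answer)            # -> [1200.5]

    # Mirror of the matching loop in mmmu_evaluation: substring containment for
    # string candidates, exact membership for float candidates.
    hit = any(
        any(isinstance(g, str) and g in p for g in golds) if isinstance(p, str)
        else p in golds
        for p in preds
    )
    print(hit)  # True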