Skip to content

Commit eb3122e

Browse files
committed
add test set
1 parent 0cd0a36 commit eb3122e

25 files changed

+384
-40
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ data/val/*
55
!data/val/scripts
66
!data/val/README-v2.md
77

8+
data/test/*
9+
!data/test/README-v1.md
10+
!data/test/scripts
11+
812
logs/*
913

1014
result/*

README.md

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
</div>
66
<br>
77

8-
## Task & Dataset Info
8+
## Task & Dataset Info.
99
[SemEval-2025 Task-3 — Mu-SHROOM](https://helsinki-nlp.github.io/shroom/)
1010

1111

@@ -27,18 +27,37 @@ Download Mu-SHROOM Dataset from [Official Website](https://helsinki-nlp.github.i
2727
sh scripts/preprocess_wiki.sh
2828
```
2929

30-
### Experiment
30+
### Experiment
31+
#### Validation Set
3132
```bash
33+
# Retrieve Contexts
34+
sh scripts/run_val_retriever.sh
35+
36+
# Our Method
37+
sh scripts/run_val_REFIND.sh
38+
39+
# Baselines
40+
sh scripts/run_val_XLM-R.sh
41+
sh scripts/run_val_FAVA.sh
42+
43+
## Evaluation
44+
sh scripts/evaluate_val.sh
45+
```
46+
47+
#### Test Set
48+
```bash
49+
# Retrieve Contexts
50+
sh scripts/run_test_retriever.sh
51+
3252
# Our Method
33-
sh scripts/run_REFIND.sh
53+
sh scripts/run_test_REFIND.sh
3454

3555
# Baselines
36-
sh scripts/run_random_guess.sh
37-
sh scripts/run_XLM-R.sh
38-
sh scripts/run_FAVA.sh
56+
sh scripts/run_test_XLM-R.sh
57+
sh scripts/run_test_FAVA.sh
3958

4059
## Evaluation
41-
sh scripts/evaluate.sh
60+
sh scripts/evaluate_test.sh
4261
```
4362

4463
## References

config/ca_config.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
REFIND:
2+
retriever: HybridRetriever
3+
input_prompt_template: REFIND_PROMPT_TEMPLATE
4+
threshold_list: [0.1, 0.2, 0.3, 0.4]
5+
FAVA:
6+
retriever: HybridRetriever
7+
8+
Retriever:
9+
language: CA
10+
input_file_path: retriever/ca_wiki_corpus.jsonl
11+
parameters:
12+
retrieval_chunk_size: 600
13+
retrieval_chunk_overlap: 30
14+
retrieval_top_k: 5
15+
HybridRetriever:
16+
language: CA
17+
input_file_path: retriever/ca_wiki_corpus.jsonl
18+
embedding_model_path: intfloat/multilingual-e5-large
19+
parameters:
20+
retrieval_chunk_size: 600
21+
retrieval_chunk_overlap: 30
22+
retrieval_top_k: 10
23+
reranking_top_k: 5

config/cs_config.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
REFIND:
2+
retriever: HybridRetriever
3+
input_prompt_template: REFIND_PROMPT_TEMPLATE
4+
threshold_list: [0.1, 0.2, 0.3, 0.4]
5+
FAVA:
6+
retriever: HybridRetriever
7+
8+
Retriever:
9+
language: CS
10+
input_file_path: retriever/cs_wiki_corpus.jsonl
11+
parameters:
12+
retrieval_chunk_size: 600
13+
retrieval_chunk_overlap: 30
14+
retrieval_top_k: 5
15+
HybridRetriever:
16+
language: CS
17+
input_file_path: retriever/cs_wiki_corpus.jsonl
18+
embedding_model_path: intfloat/multilingual-e5-large
19+
parameters:
20+
retrieval_chunk_size: 600
21+
retrieval_chunk_overlap: 30
22+
retrieval_top_k: 10
23+
reranking_top_k: 5

config/eu_config.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
REFIND:
2+
retriever: HybridRetriever
3+
input_prompt_template: REFIND_PROMPT_TEMPLATE
4+
threshold_list: [0.1, 0.2, 0.3, 0.4]
5+
FAVA:
6+
retriever: HybridRetriever
7+
8+
Retriever:
9+
language: EU
10+
input_file_path: retriever/eu_wiki_corpus.jsonl
11+
parameters:
12+
retrieval_chunk_size: 600
13+
retrieval_chunk_overlap: 30
14+
retrieval_top_k: 5
15+
HybridRetriever:
16+
language: EU
17+
input_file_path: retriever/eu_wiki_corpus.jsonl
18+
embedding_model_path: intfloat/multilingual-e5-large
19+
parameters:
20+
retrieval_chunk_size: 600
21+
retrieval_chunk_overlap: 30
22+
retrieval_top_k: 10
23+
reranking_top_k: 5

config/fa_config.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
REFIND:
2+
retriever: HybridRetriever
3+
input_prompt_template: REFIND_PROMPT_TEMPLATE
4+
threshold_list: [0.1, 0.2, 0.3, 0.4]
5+
FAVA:
6+
retriever: HybridRetriever
7+
8+
Retriever:
9+
language: FA
10+
input_file_path: retriever/fa_wiki_corpus.jsonl
11+
parameters:
12+
retrieval_chunk_size: 600
13+
retrieval_chunk_overlap: 30
14+
retrieval_top_k: 5
15+
HybridRetriever:
16+
language: FA
17+
input_file_path: retriever/fa_wiki_corpus.jsonl
18+
embedding_model_path: intfloat/multilingual-e5-large
19+
parameters:
20+
retrieval_chunk_size: 600
21+
retrieval_chunk_overlap: 30
22+
retrieval_top_k: 10
23+
reranking_top_k: 5
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import argparse as ap
2+
import os
3+
import sys
4+
5+
parent_dir = os.path.dirname(
6+
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7+
)
8+
sys.path.append(parent_dir)
9+
10+
import yaml
11+
from lib import load_jsonl_file, write_jsonl
12+
from retriever.retriever import HybridRetriever
13+
from tqdm import tqdm
14+
15+
16+
p = ap.ArgumentParser()
17+
p.add_argument("--yaml_filepath", type=str, default="config/en_config.yaml")
18+
p.add_argument("--input_filepath", type=str)
19+
args = p.parse_args()
20+
21+
22+
def main():
23+
records = load_jsonl_file(args.input_filepath)
24+
25+
with open(args.yaml_filepath, "r") as f:
26+
config = yaml.load(f, Loader=yaml.FullLoader)
27+
28+
retriever = HybridRetriever(args.yaml_filepath)
29+
30+
records_with_contexts = []
31+
for record in tqdm(records, desc="Retrieving contexts"):
32+
context_list = retriever.retrieve(
33+
query=record["model_input"], return_type="list"
34+
)
35+
assert (
36+
context_list is not None
37+
), f"Failed to retrieve contexts for record {record}"
38+
39+
record["context"] = context_list
40+
records_with_contexts.append(record)
41+
42+
input_directory = os.path.dirname(args.input_filepath)
43+
output_filename = f"context-{os.path.basename(args.input_filepath)}"
44+
output_filepath = os.path.join(input_directory, output_filename)
45+
46+
write_jsonl(records_with_contexts, output_filepath)
47+
print(f"Contexts written to {output_filepath}")
48+
49+
50+
if __name__ == "__main__":
51+
main()

lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ def load_jsonl_file(filename):
3535
Performs minor format checks (ensures that soft_labels are present, optionally compute hard_labels on the fly)."""
3636
df = pd.read_json(filename, lines=True)
3737
if 'hard_labels' not in df.columns:
38-
df['hard_labels'] = df.soft_labels.apply(recompute_hard_labels)
38+
try:
39+
df['hard_labels'] = df.soft_labels.apply(recompute_hard_labels)
40+
except AttributeError:
41+
pass
3942
# adding an extra column for convenience
4043
df['text_len'] = df.model_output_text.apply(len)
4144
return df.to_dict(orient='records')

model/FAVA.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def _find_hallucination_spans(output_text):
4343
spans = []
4444
for match in matches:
4545
content = match.group(2).strip()
46+
if content == "":
47+
continue
4648
start_idx = processed_output_text.find(content)
4749
if start_idx != -1:
4850
end_idx = start_idx + len(content)
@@ -162,6 +164,8 @@ def predict_hallucinations(
162164
exit()
163165

164166
model_output_start_idx = hallucinated_output.find(hallucinated_text)
167+
if model_output_start_idx == -1:
168+
continue
165169
model_output_end_idx = model_output_start_idx + len(hallucinated_text)
166170

167171
hard_labels.append([model_output_start_idx, model_output_end_idx])

model/REFIND.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,15 @@ def _generate_offset_mapping_manually(text, tokenizer):
7878
offset_mapping = []
7979
start = 0
8080
for token in tokens:
81-
token = token.replace("▁", " ")
8281
start = text.find(token, start)
82+
if start == -1:
83+
token = token.replace("▁", " ")
84+
start = text.find(token, start)
85+
if start == -1:
86+
token = token.replace(" ", "")
87+
start = text.find(token, start)
88+
if start == -1:
89+
continue
8390
end = start + len(token)
8491
offset_mapping.append((start, end))
8592
start = end
@@ -228,6 +235,12 @@ def main():
228235
model = AutoModelForCausalLM.from_pretrained(
229236
model_id, device_map="auto", torch_dtype=torch.bfloat16
230237
)
238+
elif model_id.replace("\/", "/") == "CohereForAI/aya-23-35B":
239+
from torch.nn import DataParallel
240+
tokenizer = AutoTokenizer.from_pretrained(model_id.replace("\/", "/"), trust_remote_code=True)
241+
model = AutoModelForCausalLM.from_pretrained(model_id.replace("\/", "/"), trust_remote_code=True)
242+
model = DataParallel(model)
243+
model.to(args.device)
231244
else:
232245
tokenizer = AutoTokenizer.from_pretrained(
233246
model_id.replace("\/", "/"), trust_remote_code=True
@@ -276,12 +289,44 @@ def main():
276289
model_output_token_ids, offsets_mapping = (
277290
_get_tokens_ids_and_offsets_mapping(tokenizer, model_output_text)
278291
)
279-
assert offsets_mapping[-1][1] == len(
280-
model_output_text
281-
), "Offsets mapping and model output text mismatch!"
282-
assert len(model_output_token_ids) == len(
283-
offsets_mapping
284-
), f"Token IDs and offsets mapping mismatch! {len(model_output_token_ids)} vs {len(offsets_mapping)}"
292+
try:
293+
assert offsets_mapping[-1][1] == len(
294+
model_output_text
295+
), "Offsets mapping and model output text mismatch!"
296+
except AssertionError as e:
297+
print(f"AssertionError: {e}")
298+
print(f"offsets_mapping: {offsets_mapping}")
299+
print(f"model_output_text: {model_output_text}")
300+
301+
# Augment offsets_mapping
302+
prev_end_idx = 0
303+
# end_idx = len(model_output_text)
304+
for i, span in enumerate(offsets_mapping):
305+
start_idx, end_idx = span
306+
if start_idx == prev_end_idx:
307+
prev_end_idx = end_idx
308+
continue
309+
else:
310+
offsets_mapping.insert(i, (prev_end_idx, start_idx))
311+
prev_end_idx = end_idx
312+
313+
# Check again
314+
assert offsets_mapping[-1][1] == len(
315+
model_output_text
316+
), "Offsets mapping and model output text mismatch!"
317+
318+
try:
319+
assert len(model_output_token_ids) == len(
320+
offsets_mapping
321+
), f"Token IDs and offsets mapping mismatch! {len(model_output_token_ids)} vs {len(offsets_mapping)}"
322+
except AssertionError as e:
323+
print(f"AssertionError: {e}")
324+
print(f"model_output_token_ids: {model_output_token_ids}")
325+
print(f"offsets_mapping: {offsets_mapping}")
326+
if len(model_output_token_ids) > len(offsets_mapping):
327+
model_output_token_ids = model_output_token_ids[: len(offsets_mapping)]
328+
else:
329+
offsets_mapping = offsets_mapping[: len(model_output_token_ids)]
285330
model_output_probs, model_output_logits = compute_output_probs(
286331
model,
287332
tokenizer,
@@ -457,6 +502,26 @@ def main():
457502
for condition in HALLUCINATION_CONDITIONS.keys():
458503
for threshold in threshold_list:
459504
predictions = total_predictions[condition][threshold]
505+
506+
# Check Validity of prediction
507+
for prediction in predictions:
508+
hard_labels = prediction["hard_labels"]
509+
soft_labels = prediction["soft_labels"]
510+
assert len(hard_labels) == len(soft_labels), "Hard and soft labels mismatch!"
511+
512+
for hard_label, soft_label in zip(hard_labels, soft_labels):
513+
if hard_label[0] < 0 or hard_label[1] > len(prediction["model_output_text"]):
514+
# remove invalid spans
515+
hard_labels.remove(hard_label)
516+
soft_labels.remove(soft_label)
517+
continue
518+
519+
assert (
520+
hard_label[0] == soft_label["start"]
521+
), "Hard and soft labels mismatch!"
522+
assert hard_label[1] == soft_label["end"], "Hard and soft labels mismatch!"
523+
524+
460525
output_file_directory = os.path.join(
461526
args.output_directory,
462527
f'{os.path.basename(args.yaml_filepath.replace(".yaml", ""))}',

0 commit comments

Comments
 (0)