relation_identity_mapper.py
import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'relation_identity_mapper'


# TODO: LLM-based inference.
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class RelationIdentityMapper(Mapper):
"""
identify relation between two entity in the text.
"""
DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
'给定关于{entity1}和{entity2}的文本信息。'
'判断{entity1}和{entity2}之间的关系。\n'
'要求:\n'
'- 关系用一个或多个词语表示,必要时可以加一个形容词来描述这段关系\n'
'- 输出关系时不要参杂任何标点符号\n'
'- 需要你进行合理的推理才能得出结论\n'
'- 如果两个人物身份是同一个人,输出关系为:另一个身份\n'
'- 输出格式为:\n'
'分析推理:...\n'
'所以{entity2}是{entity1}的:...\n'
'- 注意输出的是{entity2}是{entity1}的什么关系,而不是{entity1}是{entity2}的什么关系')
DEFAULT_INPUT_TEMPLATE = '关于{entity1}和{entity2}的文本信息:\n```\n{text}\n```\n'
DEFAULT_OUTPUT_PATTERN_TEMPLATE = r"""
\s*分析推理:\s*(.*?)\s*
\s*所以{entity2}是{entity1}的:\s*(.*?)\Z
"""

    def __init__(self,
                 api_model: str = 'gpt-4o',
                 source_entity: str = None,
                 target_entity: str = None,
                 *,
                 output_key: str = MetaKeys.role_relation,
                 api_endpoint: Optional[str] = None,
                 response_path: Optional[str] = None,
                 system_prompt_template: Optional[str] = None,
                 input_template: Optional[str] = None,
                 output_pattern_template: Optional[str] = None,
                 try_num: PositiveInt = 3,
                 drop_text: bool = False,
                 model_params: Dict = {},
                 sampling_params: Dict = {},
                 **kwargs):
        """
        Initialization method.

        :param api_model: API model name.
        :param source_entity: The source entity of the relation to be
            identified.
        :param target_entity: The target entity of the relation to be
            identified.
        :param output_key: The output key in the meta field of the samples.
            Defaults to 'role_relation'.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'choices.0.message.content'.
        :param system_prompt_template: System prompt template for the task.
        :param input_template: Template for building the model input.
        :param output_pattern_template: Regular expression template for
            parsing the model output.
        :param try_num: The number of retry attempts when there is an API
            call error or an output parsing error.
        :param drop_text: Whether to drop the text in the output samples.
        :param model_params: Parameters for initializing the API model.
        :param sampling_params: Extra parameters passed to the API call,
            e.g. {'temperature': 0.9, 'top_p': 0.95}.
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        if source_entity is None or target_entity is None:
            logger.warning('source_entity and target_entity cannot be None')
        self.source_entity = source_entity
        self.target_entity = target_entity

        self.output_key = output_key

        system_prompt_template = system_prompt_template or \
            self.DEFAULT_SYSTEM_PROMPT_TEMPLATE
        self.system_prompt = system_prompt_template.format(
            entity1=source_entity, entity2=target_entity)
        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
        output_pattern_template = output_pattern_template or \
            self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
        self.output_pattern = output_pattern_template.format(
            entity1=source_entity, entity2=target_entity)

        self.sampling_params = sampling_params
        self.model_key = prepare_model(model_type='api',
                                       model=api_model,
                                       endpoint=api_endpoint,
                                       response_path=response_path,
                                       **model_params)

        self.try_num = try_num
        self.drop_text = drop_text

    def parse_output(self, raw_output):
        pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
        matches = pattern.findall(raw_output)

        relation = ''
        for match in matches:
            _, relation = match
            relation = relation.strip()

        return relation
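
    # For reference (hypothetical example, not taken from the source): with
    # source_entity='张三' and target_entity='李四', a raw model output such as
    # "分析推理:两人是多年同学\n所以李四是张三的:同学" would be parsed into the
    # relation "同学".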

    def process_single(self, sample, rank=None):
        meta = sample[Fields.meta]
        if self.output_key in meta:
            return sample

        client = get_model(self.model_key, rank=rank)

        text = sample[self.text_key]
        input_prompt = self.input_template.format(entity1=self.source_entity,
                                                  entity2=self.target_entity,
                                                  text=text)
        messages = [{
            'role': 'system',
            'content': self.system_prompt
        }, {
            'role': 'user',
            'content': input_prompt
        }]

        relation = ''
        for i in range(self.try_num):
            try:
                output = client(messages, **self.sampling_params)
                relation = self.parse_output(output)
                if len(relation) > 0:
                    break
            except Exception as e:
                logger.warning(f'Exception: {e}')

        meta[self.output_key] = relation
        if self.drop_text:
            sample.pop(self.text_key)

        return sample
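

# A minimal usage sketch (illustrative only; the entity names and sample text
# below are hypothetical, and running it requires a configured API backend with
# valid credentials for the chosen api_model):
if __name__ == '__main__':
    op = RelationIdentityMapper(api_model='gpt-4o',
                                source_entity='张三',
                                target_entity='李四')
    sample = {
        'text': '张三和李四是多年的同学。',
        Fields.meta: {},
    }
    sample = op.process_single(sample)
    # The identified relation (e.g. '同学') is stored under the output key.
    print(sample[Fields.meta][op.output_key])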