Skip to content

Data_Processing

dlrgy22 edited this page May 25, 2021 · 7 revisions

1️⃣ Preprocess data

1. Why

  • MRC + retrieval 모두에서 크게 필요없다고 생각되는 개행문자("\n", "\\n"), 정답에 포함되어 있지 않는 특수문자, "#"등을 전처리 하고자 하였다.
  • wiki data를 그대로 가져온것이라고 생각 되었고 preprocessing이 필요하다고 판단

2. code

def preprocess(text):
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r"\\n", " ", text)
    text = re.sub(r'#', ' ', text)
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[^a-zA-Z0-9가-힣ㄱ-ㅎㅏ-ㅣぁ-ゔァ-ヴー々〆〤一-龥<>()\s\.\?!》《≪≫\'<>〈〉:‘’%,『』「」<>・\"-“”∧]", "", text)
    return text

def run_preprocess(data_dict):
    context = data_dict["context"]
    start_ids = data_dict["answers"]["answer_start"][0]
    
    before = data_dict["context"][:start_ids]
    after = data_dict["context"][start_ids:]
    
    process_before = preprocess(before)
    process_after = preprocess(after)
    process_data = process_before + process_after
    
    ids_move = len(before) - len(process_before)
    
    data_dict["context"] = process_data
    data_dict["answers"]["answer_start"][0] = start_ids - ids_move
    
    return data_dict

def check(data_list):
    for data in data_list:
        start_ids = data["answers"]["answer_start"][0]
        end_ids = start_ids + len(data["answers"]["text"][0])
        if data["answers"]["text"][0] != data["context"][start_ids : end_ids]:
            print("wrong")
            return
    print("good")
train_data = load_from_disk("train data path")["train"]
new_train_data = []
for data in train_data:
    new_data = run_preprocess(data)
    new_train_data.append(new_data)
    
check(new_train_data)

preprocess

  • 개행문자 "\n", "\\n" 를 " "으로 대체
  • "#"을 " "으로 대체
  • 개행문자, "#"를 제외하면서 " "이 두번이상 반복될 시 " " 한번으로 대체
  • 영어, 한글, 한문, 일본어 + 정답에 존재하는 특수문자를 제외한 모든 문자들을 " "으로 대체

run preprocess

  • 전처리 + 전처리후 answer의 start index 찾기
  • answer의 start index 기준으로 before, after로 분리하여 preprocess
  • before의 길이와 process_before의 길이를 비교하여 process data의 start index 정의

3. result

Before

after

2️⃣ Concat data

1. Why + requirement

  • 실제 모델 파이프라인에서는 1개의 question에 대하여 retrieval에서 top k개의 context를 선택후 모두 concat 하여 MRC의 input으로 사용, 이 때 train에서 1개의 question에 대해서 1개의 context만을 사용하여 정답을 낼 경우 validation을 신용하기 힘들다.

  • 1개의 question에 대하여 유사도가 높은 context를 concat하여 사용 할 경우 정답이 cls 토큰인 데이터가 augmentation되는 효과

  • Elastic Search를 이용한 retrieval이 필요

    How to use Elastic Search : Elastic Search for Beginners

2. Code

from elasticsearch import Elasticsearch

config = {'host':'localhost', 'port':9200}
es = Elasticsearch([config])

# test connection
es.ping()

def search_es(es_obj, index_name, question_text, n_results):
    query = {
            'query': {
                'match': {
                    'document_text': question_text
                    }
                }
            }
    
    res = es_obj.search(index=index_name, body=query, size=n_results)
    
    return res
# train_qa : train data list
for step, question in enumerate(train_qa):
    k = 5
    res = search_es(es, "wiki_split", question["question"], k)
    context_list = [(hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']]
    add_text = train_qa[step]["context"]
    count = 0
    for context in context_list:
        #같은것이 있을 경우 continue 하여 concat X
        if question["context"] == context[0]:
            continue
        add_text += " " + context[0]
        count += 1
        if count == 4:
            break
    train_qa[step]["context"] = add_text

3. result

question과 유사한 context K개를 concat한 데이터 생성

3️⃣ wiki split

1. Why

  • 단순하게 retrieval에서 많은 context를 뽑아 concat하여 MRC모델에게 넘겨주어 MRC에게 많은 부담을 주는 문제 해결

  • MRC의 부담을 줄여주기 위하여 retrieval에 들어가는 context을 문장단위로 split하여 사용

  • trade off였던 retrieval의 성능저하가 크지 않다. (Top 5, 10, 15 기준)

2. code

def passage_split(text):
  	split_length = 400
    num = len(text) // split_length
    count = 1
    split_datas = kss.split_sentences(text)
    data_list = []
    data = ""
    for split_data in split_datas:
        if abs(len(data) - split_length) > abs(len(data) + len(split_data) - split_length) and count < num:
            if len(data) == 0:
                data += split_data
            else:
                data += (" " + split_data)
        elif count < num:
            data_list.append(data)
            count += 1
            data = ""
            data += split_data
        else:
            data += split_data
        
    data_list.append(data)
    return data_list, len(data_list)
with open("/opt/ml/input/data/preprocess_wiki.json", "r") as f:
    wiki = json.load(f)
    
new_wiki = dict()
for i in range(len(wiki)):
    if len(wiki[str(i)]["text"]) < 800:
        new_wiki[str(i)] = wiki[str(i)]
        continue
        
    data_list, count = passage_split(wiki[str(i)]["text"])
    for j in range(count):
        new_wiki[str(i) + f"_{j}"] = {"text" : data_list[j], 
        															"corpus_source" : wiki[str(i)]["corpus_source"], 
        															"url" :  wiki[str(i)]["url"], 
        															"domain" : wiki[str(i)]["domain"], 
        															"title" : wiki[str(i)]["title"], 
        															"author" : wiki[str(i)]["author"], 
        															"html" : wiki[str(i)]["html"], 
        															"document_id" : wiki[str(i)]["document_id"]}

passage_split

  • context의 길이가 일정 이상(800)일 경우 split
  • kss 라이브러리를 사용하여 문장단위로 split
  • 받은 context의 길이에 따라 split되는 문장의 수를 달리한다. (최대한 같은 길이로 split)

4️⃣ question type data

1. why + requirement

  • Query Attention model에 사용할 question type이 정의 되어있는 데이터셋 구축

    Query Attention model

  • AI hub data를 이용하여 Question type을 예측하는 model 학습

    AI hub 기계독해 데이터

    train code

  • Question type class : 'work_how':0, 'work_what':1, 'work_when':2, 'work_where':3, 'work_who':4, 'work_why':5

2. code

# model : question type classification model (train by AI hub data)
def question_labeling(model, train_iter, val_iter):
    train_file = get_pickle("../../data/concat_train.pkl")["train"]
    validation_file = get_pickle("../../data/concat_train.pkl")["validation"]

    train_qa = [{"id" : train_file[i]["id"],
                 "question" : train_file[i]["question"], 
                 "answers" : train_file[i]["answers"], 
                 "context" : train_file[i]["context"]} for i in range(len(train_file))]
    validation_qa = [{"id" : validation_file[i]["id"], 
                      "question" : validation_file[i]["question"], 
                      "answers" : validation_file[i]["answers"], 
                      "context" : validation_file[i]["context"]} for i in range(len(validation_file))]
    
    device = "cuda:0"
    for step, (input_ids, attention_mask, labels) in tqdm(enumerate(train_iter), total=len(train_iter), position=0, leave=True):
        score = model(input_ids.to(device), attention_mask=attention_mask.to(device))[0]
        pred = torch.argmax(score, 1).detach().cpu().numpy()
        train_qa[step]["question_type"] = pred
    
    for step, (input_ids, attention_mask, labels) in tqdm(enumerate(val_iter), total=len(val_iter), position=0, leave=True):
        score = model(input_ids.to(device), attention_mask=attention_mask.to(device))[0]
        pred = torch.argmax(score, 1).detach().cpu().numpy()
        validation_qa[step]["question_type"] = pred
    
    train_df = pd.DataFrame(train_qa)
    val_df = pd.DataFrame(validation_qa)
        
    return train_df, val_df
def save_data(train_df, val_df):
    train_f = Features({'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None),
                        'context': Value(dtype='string', id=None),
                        'id': Value(dtype='string', id=None),
                        'question': Value(dtype='string', id=None),
                        'question_type' : Value(dtype='int32', id=None)})
    
    train_datasets = DatasetDict({'train': Dataset.from_pandas(train_df, features=train_f), 'validation': Dataset.from_pandas(val_df, features=train_f)})
    file = open("../../data/question_type.pkl", "wb")
    pickle.dump(train_datasets, file)
    file.close()

question_labeling

AI hub로 사전에 학습된 모델을 이용하여 대회 train data에 대하여 question type을 추가