Skip to content

Commit

Permalink
Change to minimal test environment
Browse files Browse the repository at this point in the history
  • Loading branch information
taishi-i committed Jul 31, 2023
1 parent 5913e95 commit 6b75fc9
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 102 deletions.
105 changes: 52 additions & 53 deletions test/test_classifiers.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,52 @@
from toiro import classifiers
from toiro import datadownloader


def test_classifier_svm():
    """End-to-end check of the SVM classifier on the livedoor news corpus.

    Downloads the corpus, trains a model, round-trips it through a pickle
    file, and asserts a sample sentence is classified as expected.
    """
    # The first entry of available_corpus() is the livedoor news corpus.
    corpus_name = datadownloader.available_corpus()[0]
    datadownloader.download_corpus(corpus_name)
    train_df, dev_df, test_df = datadownloader.load_corpus(corpus=corpus_name)

    # Train a fresh SVM model and report evaluation metrics.
    svm = classifiers.SVMClassificationModel()
    svm.fit(train_df, dev_df)
    result = svm.eval(test_df)
    print(result)
    print(result['accuracy_score'])
    print(result['elapsed_time'])

    # Persist the model and reload it from disk before predicting.
    model_path = f"{corpus_name}.pkl"
    svm.save(model_path)
    reloaded = classifiers.SVMClassificationModel(model_file=model_path)

    prediction = reloaded.predict("Python で前処理を")
    assert prediction == "dokujo-tsushin"


def test_classifier_bert():
    """Smoke-test the BERT classifier when the optional BERT extra is installed.

    When BERT support is unavailable, only verify that the availability
    check reports False.
    """
    if not classifiers.is_bert_available():
        # BERT is an optional dependency; nothing more to exercise without it.
        assert classifiers.is_bert_available() is False
        return

    # Train on the bundled sample datasets.
    train_df = classifiers.read_file(
        datadownloader.sample_datasets.sample_train
    )
    dev_df = classifiers.read_file(
        datadownloader.sample_datasets.sample_dev
    )

    # test_df = classifiers.read_file(
    #     datadownloader.sample_datasets.sample_test
    # )

    bert = classifiers.BERTClassificationModel()
    bert.fit(train_df, dev_df)

    # Run a prediction to confirm the trained model is usable.
    prediction = bert.predict("Python で前処理を")
from toiro import classifiers, datadownloader


# def test_classifier_svm():
# # Download the livedoor news corpus and load it as pandas.DataFrame
# corpora = datadownloader.available_corpus()
# livedoor_corpus = corpora[0]
# datadownloader.download_corpus(livedoor_corpus)
# train_df, dev_df, test_df = datadownloader.load_corpus(
# corpus=livedoor_corpus
# )
#
# model = classifiers.SVMClassificationModel()
# model.fit(train_df, dev_df)
# eval_result = model.eval(test_df)
# print(eval_result)
# print(eval_result["accuracy_score"])
# print(eval_result["elapsed_time"])
#
# model.save(f"{livedoor_corpus}.pkl")
# model = classifiers.SVMClassificationModel(
# model_file=f"{livedoor_corpus}.pkl"
# )
#
# text = "Python で前処理を"
# pred_y = model.predict(text)
#
# expected = "dokujo-tsushin"
# assert pred_y == expected
#
#
# def test_classifier_bert():
# if classifiers.is_bert_available():
# train_df = classifiers.read_file(
# datadownloader.sample_datasets.sample_train
# )
#
# dev_df = classifiers.read_file(
# datadownloader.sample_datasets.sample_dev
# )
#
# # test_df = classifiers.read_file(
# # datadownloader.sample_datasets.sample_test
# # )
#
# model = classifiers.BERTClassificationModel()
# model.fit(train_df, dev_df)
#
# text = "Python で前処理を"
# pred_y = model.predict(text)
# else:
# assert classifiers.is_bert_available() is False
98 changes: 49 additions & 49 deletions test/test_datadownloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ def test_check_correct_corpus_type_error():
datadownloader.download_corpus(corpus=corpus)


def test_download_corpus():
    """Download every available corpus and verify each archive lands on disk."""
    for corpus in datadownloader.available_corpus():
        datadownloader.download_corpus(corpus)

        # The downloaded archive should exist under the resource directory.
        corpora_dict = datadownloader.get_corpora_dict()
        resource_dir = datadownloader.get_resource_dir()
        archive_path = os.path.join(
            resource_dir, corpora_dict[corpus]['filename']
        )
        assert os.path.exists(archive_path)
# def test_download_corpus():
# available_corpus = datadownloader.available_corpus()
# for corpus in available_corpus:
# datadownloader.download_corpus(corpus)
#
# corpora_dict = datadownloader.get_corpora_dict()
# resource_dir = datadownloader.get_resource_dir()
#
# filename = corpora_dict[corpus]['filename']
# filepath = os.path.join(resource_dir, filename)
# assert os.path.exists(filepath)


def test_split_train_dev_test_error():
Expand All @@ -46,41 +46,41 @@ def test_split_train_dev_test_error():
)


def test_load_corpus():
    """Load each available corpus and check its train/dev/test split sizes."""
    # Expected number of rows in each split, keyed by corpus name.
    expected_sizes = {
        'livedoor_news_corpus': {'train': 5900, 'dev': 737, 'test': 737},
        'yahoo_movie_reviews': {'train': 72956, 'dev': 9119, 'test': 9119},
        'amazon_reviews': {'train': 209944, 'dev': 26243, 'test': 26243},
        'chABSA_dataset': {'train': 4895, 'dev': 611, 'test': 611}
    }

    for corpus in datadownloader.available_corpus():
        # yahoo_movie_reviews is loaded with its 'original' corpus_type;
        # the other corpora use the default loading path.
        if corpus == 'yahoo_movie_reviews':
            splits = datadownloader.load_corpus(
                corpus=corpus, corpus_type='original'
            )
        elif corpus in (
            'livedoor_news_corpus', 'amazon_reviews', 'chABSA_dataset'
        ):
            splits = datadownloader.load_corpus(corpus=corpus)
        train_df, dev_df, test_df = splits

        sizes = expected_sizes[corpus]
        assert len(train_df) == sizes['train']
        assert len(dev_df) == sizes['dev']
        assert len(test_df) == sizes['test']
# def test_load_corpus():
# available_corpus = datadownloader.available_corpus()
#
# num_corpus = {
# 'livedoor_news_corpus': {'train': 5900, 'dev': 737, 'test': 737},
# 'yahoo_movie_reviews': {'train': 72956, 'dev': 9119, 'test': 9119},
# 'amazon_reviews': {'train': 209944, 'dev': 26243, 'test': 26243},
# 'chABSA_dataset': {'train': 4895, 'dev': 611, 'test': 611}
# }
#
# for corpus in available_corpus:
# if corpus == 'livedoor_news_corpus':
# train_df, dev_df, test_df = datadownloader.load_corpus(
# corpus=corpus
# )
#
# elif corpus == 'yahoo_movie_reviews':
# train_df, dev_df, test_df = datadownloader.load_corpus(
# corpus=corpus, corpus_type='original'
# )
#
# elif corpus == 'amazon_reviews':
# train_df, dev_df, test_df = datadownloader.load_corpus(
# corpus=corpus
# )
# elif corpus == 'chABSA_dataset':
# train_df, dev_df, test_df = datadownloader.load_corpus(
# corpus=corpus
# )
#
# num_data = num_corpus[corpus]
# excepted_train = num_data['train']
# excepted_dev = num_data['dev']
# excepted_test = num_data['test']
#
# assert len(train_df) == excepted_train
# assert len(dev_df) == excepted_dev
# assert len(test_df) == excepted_test

0 comments on commit 6b75fc9

Please sign in to comment.