Skip to content

Commit

Permalink
black
Browse files (browse the repository at this point in the history)
  • Loading branch information
birkirarndal committed Dec 11, 2023
1 parent b279530 commit 95c287a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
2 changes: 0 additions & 2 deletions src/BaselineClassifiersBinary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@
" print(lr_pipeline)\n",
" print(classification_report(y_test, predict_lr, digits=4))\n",
"\n",
" \n",
" return (\n",
" (\n",
" {\n",
Expand Down Expand Up @@ -396,7 +395,6 @@
" plt.show()\n",
"\n",
"\n",
"\n",
"data2, nb_mideind, svc_mideind, lr_mideind = classify(ICELANDIC_MIDEIND_CSV)\n",
"data3, nb_google, svc_google, lr_google = classify(ICELANDIC_GOOGLE_CSV)\n",
"data1, nb_english, svc_english, lr_english = classify(ENGLISH_CSV)\n",
Expand Down
7 changes: 4 additions & 3 deletions src/generate_classification_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,13 @@ def generate_report(self, accuracy):
prediction = torch.max(outputs.logits, dim=1)
y_true.extend(labels.tolist())
y_pred.extend(prediction.indices.tolist())

if accuracy:
acc = accuracy_score(y_true, y_pred)
return acc
return classification_report(y_true, y_pred, output_dict=True) # NOTE: can use this if you want to print classification report
return classification_report(
y_true, y_pred, output_dict=True
) # NOTE: can use this if you want to print classification report


class DataFrameLoader:
Expand Down Expand Up @@ -165,7 +167,6 @@ def generate_report(filename, folder, device):
print("Loading model from folder {} using file {}".format(folder, filename))
dfl = DataFrameLoader(filename)
return call_model(dfl.X_all, dfl.y_all, folder, device)



def eval_files():
Expand Down
40 changes: 25 additions & 15 deletions src/generate_report.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,38 +64,48 @@
"import random\n",
"import generate_classification_report as gcr\n",
"\n",
"PATH1 = ''\n",
"PATH2 = ''\n",
"PATH1 = \"\"\n",
"PATH2 = \"\"\n",
"\n",
"d1 = pd.read_csv(PATH1)\n",
"d2 = pd.read_csv(PATH2)\n",
"d1.drop(['num', 'rating', 'id'], axis=1, inplace=True)\n",
"d2.drop(['movie', 'rating'], axis=1, inplace=True)\n",
"d1.drop([\"num\", \"rating\", \"id\"], axis=1, inplace=True)\n",
"d2.drop([\"movie\", \"rating\"], axis=1, inplace=True)\n",
"\n",
"\n",
"df_orig = pd.merge(d1, d2, how='outer')\n",
"df_orig = pd.merge(d1, d2, how=\"outer\")\n",
"\n",
"device = 'cuda'\n",
"model = './electra-base-google-batch8-remove-noise-model/'\n",
"device = \"cuda\"\n",
"model = \"./electra-base-google-batch8-remove-noise-model/\"\n",
"\n",
"\n",
"total = 0\n",
"for i in range(0, 10):\n",
" r = random.randint(0, 10000)\n",
" \n",
" fifty_negative = df_orig.where(lambda x: x['sentiment'] == 'Negative').dropna().sample(n=50, random_state=r)\n",
" fifty_positive = df_orig.where(lambda x: x['sentiment'] == 'Positive').dropna().sample(n=50, random_state=r)\n",
"\n",
" new_df = pd.merge(fifty_negative, fifty_positive, on=['sentiment', 'review'], how='outer')\n",
" new_df.sentiment = new_df.sentiment.apply(lambda x: 1 if x == 'Positive' else 0)\n",
" fifty_negative = (\n",
" df_orig.where(lambda x: x[\"sentiment\"] == \"Negative\")\n",
" .dropna()\n",
" .sample(n=50, random_state=r)\n",
" )\n",
" fifty_positive = (\n",
" df_orig.where(lambda x: x[\"sentiment\"] == \"Positive\")\n",
" .dropna()\n",
" .sample(n=50, random_state=r)\n",
" )\n",
"\n",
" new_df = pd.merge(\n",
" fifty_negative, fifty_positive, on=[\"sentiment\", \"review\"], how=\"outer\"\n",
" )\n",
" new_df.sentiment = new_df.sentiment.apply(lambda x: 1 if x == \"Positive\" else 0)\n",
" X_all = new_df.review\n",
" y_all = new_df.sentiment\n",
" accuracy = gcr.call_model(X_all, y_all, model, device, accuracy=True)\n",
" total += accuracy\n",
" print('acc: {0:.4f}, seed: {1}, i: {2}'.format(accuracy, r, i))\n",
" print(\"acc: {0:.4f}, seed: {1}, i: {2}\".format(accuracy, r, i))\n",
"\n",
"\n",
" \n",
"print('Average accuracy: ', total/10)\n"
"print(\"Average accuracy: \", total / 10)"
]
},
{
Expand Down

0 comments on commit 95c287a

Please sign in to comment.