Skip to content

Commit 286f581

Browse files
committed
Tests: reproducibility, CLI plot behavior, and smoke predict for text sentiment example
1 parent 2d805ae commit 286f581

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import os
2+
import sys
3+
import numpy as np
4+
import pytest
5+
6+
datasets = pytest.importorskip("datasets")
7+
8+
# Import the example as a module (pytest adds repo root to sys.path)
9+
from examples.text_sentiment_svm_with_resampling import (
10+
main,
11+
load_tweet_eval,
12+
)
13+
14+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_loader_reproducible_small():
    """Loading twice with the same seed must return identical data."""
    loader_kwargs = {"max_samples": 800, "random_state": 42}
    feats_a, labels_a, held_feats_a, held_labels_a = load_tweet_eval(**loader_kwargs)
    feats_b, labels_b, held_feats_b, held_labels_b = load_tweet_eval(**loader_kwargs)

    # Feature containers are compared with ==, label containers with
    # np.array_equal, pairwise between the first and second call.
    assert feats_a == feats_b
    assert np.array_equal(labels_a, labels_b)
    assert held_feats_a == held_feats_b
    assert np.array_equal(held_labels_a, held_labels_b)
23+
24+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_smoke_predicts_labels_small():
    """Smoke test: fit a small text pipeline and check its predictions.

    Verifies that one prediction is produced per held-out sample and that
    every predicted value belongs to the label set {0, 1, 2}.
    """
    train_texts, train_labels, test_texts, test_labels = load_tweet_eval(
        max_samples=800, random_state=0
    )

    # Local imports; the pipeline below is intended to mirror the one built
    # in the example module.
    from imblearn.pipeline import Pipeline
    from imblearn.under_sampling import RandomUnderSampler
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.svm import LinearSVC

    vectorizer = TfidfVectorizer(min_df=2, ngram_range=(1, 2))
    sampler = RandomUnderSampler(random_state=0)
    classifier = LinearSVC()
    model = Pipeline(
        [
            ("tfidf", vectorizer),
            ("balance", sampler),
            ("clf", classifier),
        ]
    )

    model.fit(train_texts, train_labels)
    predictions = model.predict(test_texts)

    assert len(predictions) == len(test_labels)

    # Every distinct predicted value must be one of the expected labels.
    allowed_labels = {0, 1, 2}
    assert set(np.unique(predictions)) <= allowed_labels
44+
45+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_cli_saves_plot(tmp_path):
    """With --plot, the CLI must leave a non-empty file at the --output path."""
    image_path = tmp_path / "cm.png"
    cli_args = ["--plot", "--max-samples", "800", "--output", str(image_path)]

    main(cli_args)

    # The file has to exist and must contain at least one byte.
    assert image_path.exists()
    assert image_path.stat().st_size > 0
51+
52+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_cli_no_plot_no_file(tmp_path):
    """CLI: without --plot, no image should be created at the --output path.

    Args:
        tmp_path: pytest's built-in temporary-directory fixture.
    """
    # tmp_path is a fresh directory unique to this test invocation, so the
    # target file cannot exist before main() runs; no pre-cleanup is needed.
    out = tmp_path / "cm.png"

    main(["--max-samples", "500", "--output", str(out)])

    assert not out.exists()

0 commit comments

Comments
 (0)