forked from foxtrotmike/CS909
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotit.py
113 lines (100 loc) · 4.42 KB
/
plotit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
"""
@author: Dr. Fayyaz Minhas
@author-email: afsar at pieas dot edu dot pk
2D Scatter Plotter for Classification
"""
from numpy.random import randn #importing randn
import numpy as np #importing numpy
import matplotlib.pyplot as plt #importing plotting module
import itertools
import warnings
def plotit(X, Y=None, clf=None, conts=None, ccolors=('b', 'k', 'r'),
           colors=('c', 'y'), markers=('s', 'o'), hold=False,
           transform=None, extent=None, **kwargs):
    """
    A function for showing data scatter plot and classification boundary
    of a classifier for 2D data.

    X: n x d matrix of data points; columns 0 and 1 are plotted
       (d must be 2 when clf is given)
    Y: (optional) n vector of class labels
    clf: (optional) classification/discriminant function handle, called as
         clf(T, **kwargs) where T is an (npts*npts) x 2 array of grid points;
         its output is reshaped to an npts x npts grid
    conts: (optional) contour levels of clf's output to draw. If None and Y
           has at least two classes, one contour is drawn on each boundary
           between consecutive (sorted) class ids, i.e. at their midpoints;
           if None and Y is not given, levels close to (-1, 0, 1) are used
    ccolors: (optional) colors for contours
    colors: (optional) colors for each class (sorted wrt class id)
            can be 'scaled' (gray levels) or 'random' or a list/tuple of
            color ids
    markers: (optional) markers for each class (sorted wrt class id)
    hold: Whether to hold the plot or not for overlay (default: False).
          When False, a grid is added and plt.show() is called.
    transform: (optional) a function handle for transforming data before
               passing to clf
    extent: (optional) [xmin, xmax, ymin, ymax] of the plot; computed from X
            (with a tiny margin) when None
    kwargs: any keyword arguments to be passed to clf (if any)

    Returns the extent [xmin, xmax, ymin, ymax] that was used, or None when
    clf is given but X is not 2-dimensional (a warning is issued instead).
    """
    X = np.asarray(X)
    if clf is not None and (X.ndim != 2 or X.shape[1] != 2):
        warnings.warn("Data Dimensionality is not 2. Unable to plot.")
        return
    if markers is None:
        markers = ('.',)
    eps = 1e-6
    d0, d1 = (0, 1)
    if extent is None:
        minx, maxx = np.min(X[:, d0]) - eps, np.max(X[:, d0]) + eps
        miny, maxy = np.min(X[:, d1]) - eps, np.max(X[:, d1]) + eps
        extent = [minx, maxx, miny, maxy]
    else:
        minx, maxx, miny, maxy = extent
    if Y is not None:
        # Accept lists / column vectors too, so the `Y == label` masks below
        # are elementwise boolean arrays.
        Y = np.ravel(Y)
        classes = sorted(set(Y))

    if clf is not None:
        user_levels = conts is not None and len(conts) > 0
        # Colour-scale limits of the decision surface (only needed here, so
        # non-numeric labels still work for a plain scatter plot).
        if conts is not None and len(conts) >= 2:
            vmin, vmax = np.min(conts) - eps, np.max(conts) + eps
        elif Y is not None:
            vmin, vmax = classes[0] - eps, classes[-1] + eps
        else:
            vmin, vmax = -2 - eps, 2 + eps
        if not user_levels:
            if Y is None:
                conts = [-1 + eps, 0, 1 - eps]
            elif len(classes) > 1:
                # One contour on each boundary between consecutive class ids.
                conts = [(lo + hi) / 2.0
                         for lo, hi in zip(classes[:-1], classes[1:])]
            else:
                # A single class has no boundary; leave the levels to
                # matplotlib (the levels argument is omitted below).
                conts = None

        npts = 150
        x = np.linspace(minx, maxx, npts)
        y = np.linspace(miny, maxy, npts)
        # Grid points with x varying slowest (the same order as
        # itertools.product(x, y)); reshaping and transposing then puts rows
        # along y and columns along x, as contour/pcolormesh expect.
        gx, gy = np.meshgrid(x, y, indexing='ij')
        t = np.column_stack((gx.ravel(), gy.ravel()))
        if transform is not None:
            t = transform(t)
        z = np.reshape(clf(t, **kwargs), (npts, npts)).T
        level_args = () if conts is None else (conts,)
        # 'label' is not a contour() argument and 'extent' is ignored when x
        # and y are supplied, so neither is passed.
        plt.contour(x, y, z, *level_args, linewidths=[2], colors=ccolors)
        plt.pcolormesh(x, y, z, cmap=plt.cm.Purples, vmin=vmin, vmax=vmax)
        plt.colorbar()
        plt.axis(extent)

    if Y is not None:
        n_classes = len(classes)
        for i, label in enumerate(classes):
            if colors is None or (isinstance(colors, str) and colors == 'scaled'):
                cc = np.full((1, 3), i / float(n_classes))  # gray level per class
            elif isinstance(colors, str) and colors == 'random':
                cc = np.random.rand(1, 3)
            else:
                cc = colors[i % len(colors)]
            mm = markers[i % len(markers)]
            mask = Y == label
            plt.scatter(X[mask, d0], X[mask, d1], marker=mm, c=cc, s=50)
    else:
        plt.scatter(X[:, d0], X[:, d1], marker=markers[0], c='k', s=5)
    plt.xlabel('$x_1$')
    plt.ylabel('$x_2$')
    if not hold:
        plt.grid()
        plt.show()
    return extent
def getExamples(n=100, d=2):
    """
    Generate a toy two-class dataset of normally distributed points.

    n: number of examples drawn for each class
    d: dimensionality of each example
    Returns (X, Y): X stacks n standard-normal samples shifted by +1 in every
    coordinate (positive class) on top of n standard-normal samples shifted
    by -1 (negative class); Y holds the matching labels, +1 for the first n
    rows and -1 for the last n rows.
    """
    # The positive class is drawn first, so the random stream is consumed in
    # the same order as before.
    positives = randn(n, d) + 1
    negatives = randn(n, d) - 1
    X = np.concatenate((positives, negatives), axis=0)
    Y = np.array([+1 if k < n else -1 for k in range(2 * n)])
    return (X, Y)
if __name__ == '__main__':
    X, Y = getExamples()

    def clf(x):
        """Dummy linear discriminant: 2 * (sum of each row) - 2.5."""
        return 2 * np.sum(x, axis=1) - 2.5

    plotit(X=X, Y=Y, clf=clf, conts=[-1, 0, 1], colors='random')