-
Notifications
You must be signed in to change notification settings - Fork 0
/
mean_shift.py
37 lines (28 loc) · 972 Bytes
/
mean_shift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
# Switch matplotlib to its 'ggplot' stylesheet, then create one figure holding
# a single Axes; the module-level `ax` is the Axes that draw_circle() adds to.
style.use('ggplot')
fig, ax = plt.subplots(nrows=1, ncols=1)
def norm(x1, x2):
    """Return the Euclidean distance between points ``x1`` and ``x2``.

    The inputs may be NumPy arrays, plain sequences (lists/tuples) or scalars:
    ``np.subtract`` converts them to arrays before differencing, so
    ``norm([0, 0], [3, 4])`` works as well as the array form.

    >>> float(norm([0, 0], [3, 4]))
    5.0
    """
    return np.linalg.norm(np.subtract(x1, x2))
def draw_circle(x=0, y=0, r=1, *, alpha=0.35, axes=None):
    """Draw a translucent filled circle of radius ``r`` centred at ``(x, y)``.

    Args:
        x, y: Centre coordinates, in data units.
        r: Radius, in data units.
        alpha: Fill opacity (0 transparent .. 1 opaque); defaults to 0.35.
        axes: Axes to draw on; defaults to the module-level ``ax``.

    Returns:
        The circle patch that was added, so callers may restyle or remove it.
    """
    target = ax if axes is None else axes
    circle = plt.Circle((x, y), r, alpha=alpha)
    # add_patch (unlike add_artist) also updates the Axes' data limits, so the
    # whole circle is taken into account when the view autoscales.
    target.add_patch(circle)
    return circle
def neighbors(point, X, distance):
    """Return the elements of ``X`` lying within ``distance`` of ``point``.

    Distances are measured with the module-level ``norm`` helper (Euclidean).

    Args:
        point: Query point.
        X: Iterable of candidate points (e.g. the rows of a 2-D array).
        distance: Inclusive radius; a candidate exactly ``distance`` away
            is kept.

    Returns:
        A list of the qualifying elements, as yielded by iterating ``X``,
        in their original order. An empty ``X`` gives an empty list.
    """
    return [candidate for candidate in X if norm(point, candidate) <= distance]
def kernel(distance, bandwidth):
    """Evaluate a Gaussian density with standard deviation ``bandwidth``.

    Returns ``exp(-0.5 * (distance / bandwidth) ** 2) / (bandwidth * sqrt(2*pi))``,
    i.e. the normal pdf at ``distance`` from the mean. NumPy arrays are
    handled element-wise.
    """
    normalizer = 1 / (bandwidth * np.sqrt(2 * np.pi))
    scaled = distance / bandwidth
    return normalizer * np.exp(-0.5 * scaled ** 2)
# Synthetic 2-D demo dataset: x-coordinates are the integers 1..N_POINTS and
# y-coordinates are uniform random floats in [0, 50). No seed is set, so the
# y values differ on every run.
N_POINTS = 50
xs = np.arange(1, N_POINTS + 1)          # xs: integers 1..50
ys = 50 * np.random.random(N_POINTS)     # ys: uniform floats in [0, 50)
X = np.column_stack((xs, ys))            # X: (N_POINTS, 2) feature matrix [x, y]
class Mean_Shift:
def __init__(self, bandwidth=4):
self.bandwidth = bandwidth
def fit(self, data):
centroids = {}
# make all data points centroids
for i in range(len(data)):
centroids[i] = data[i]