diff --git a/.gitignore b/.gitignore
index edd980e..baf75a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,5 @@ strava
 *.pyc
 *.osm
 *.pkl
+.DS_Store
+.env
diff --git a/heatmap/__init__.py b/heatmap/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/download.py b/heatmap/download.py
similarity index 100%
rename from download.py
rename to heatmap/download.py
diff --git a/draw.py b/heatmap/draw.py
similarity index 92%
rename from draw.py
rename to heatmap/draw.py
index 30779ac..bfd008e 100755
--- a/draw.py
+++ b/heatmap/draw.py
@@ -15,13 +15,12 @@
 import osm
 
 # TODO: move to argparse
-use_osm = True
 osm_color = "salmon"
 osm_line_width = .1
 osm_alpha = .5
 
 
-def plot(data, background_color, line_width, line_color, line_alpha, dpi, label=0):
+def plot(data, background_color, line_width, line_color, line_alpha, dpi, use_osm, label=0):
     if line_color.startswith("cmap:"):
         use_cmap = True
         max_elev = max([max(d["elevs"]) for d in data])
@@ -115,17 +114,16 @@ def load_gpx(files, data=None):
             gpx = gpxpy.parse(f)
 
         track = gpx.tracks[0]
-        segment = track.segments[0]
-
-        data["tracks"].append({
-            "lats": np.array([p.latitude for p in segment.points]),
-            "lons": np.array([p.longitude for p in segment.points]),
-            "elevs": np.array([p.elevation for p in segment.points]),
-            "type": int(track.type),
-            "name": track.name,
-            "date": gpx.time,
-            "filename": os.path.basename(path)
-        })
+        for segment in track.segments:
+            data["tracks"].append({
+                "lats": np.array([p.latitude for p in segment.points]),
+                "lons": np.array([p.longitude for p in segment.points]),
+                "elevs": np.array([p.elevation for p in segment.points]),
+                "type": int(track.type),
+                "name": track.name,
+                "date": gpx.time,
+                "filename": os.path.basename(path)
+            })
     print(f"loaded {len(data)} file(s)")
     file_set = set(os.path.basename(f) for f in files)
     if "files" in data:
@@ -156,6 +154,8 @@ def add_shared_args(parser):
         help="if defined only include this activity type")
     parser.add_argument("--gpx-dir", default="strava",
         help="directory with gpx files")
+    parser.add_argument("--use-osm", default=False, action="store_true",
+        help="overlay heatmap on top of OpenStreetMap")
 
 
 parser = ArgumentParser()
@@ -179,7 +179,7 @@ def add_shared_args(parser):
 
 args = parser.parse_args()
 
-plot_keys = ["background_color", "line_color", "line_width", "line_alpha", "dpi"]
+plot_keys = ["background_color", "line_color", "line_width", "line_alpha", "dpi", "use_osm"]
 plot_args = {k: getattr(args, k) for k in plot_keys}
 
 cache_path = os.path.join(args.gpx_dir, "cache.pkl")
@@ -216,7 +216,7 @@ def add_shared_args(parser):
     coords = np.array([[np.average(d["lats"][0]), np.average(d["lons"][0])] for d in data])
 
 if args.type == "cluster":
-    cluster = DBSCAN(eps=args.radius, min_samples=10)
+    cluster = DBSCAN(eps=args.radius, min_samples=args.min_cluster_size)
     cluster.fit(coords)
     n_clusters = np.max(cluster.labels_) + 1
     centroids = [np.mean(coords[cluster.labels_ == l], axis=0) for l in range(n_clusters)]
diff --git a/osm.py b/heatmap/osm.py
similarity index 100%
rename from osm.py
rename to heatmap/osm.py
diff --git a/heatmap/rogue.py b/heatmap/rogue.py
new file mode 100755
index 0000000..060c11d
--- /dev/null
+++ b/heatmap/rogue.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+from gpxpy import parse
+from gpxpy.gpx import GPXTrackSegment
+from math import atan2, cos, radians, sin, sqrt
+from glob import glob
+import os
+import multiprocessing
+
+
+def break_segment(segment, break_points): 
+    new = [segment]
+    if len(break_points) != 0:
+        i = 0
+        for b in break_points:
+            old = new.pop()
+            new1, new2 = old.split(b-i)
+            if len(new1.points) != 0:
+                new.append(new1)
+            if len(new2.points) == 0:
+                return new
+            new.append(new2)
+            i += b
+    return new
+
+
+def find_breaks(segment, max_dist):
+    breaks = []
+    for i in range(len(segment.points)-1):
+        if distance(segment.points[i+1], segment.points[i]) > max_dist:
+            breaks.append(i)
+    return breaks
+
+
+def fix_segments(segments, max_dist):
+    s = []
+    for segment in segments:
+        s.extend(break_segment(segment, find_breaks(segment, max_dist)))
+    return s
+
+
+def distance(origin, destination):
+    lat1, lon1 = origin.latitude, origin.longitude
+    lat2, lon2 = destination.latitude, destination.longitude
+    radius = 6371  # km
+    dlat = radians(lat2-lat1)
+    dlon = radians(lon2-lon1)
+    a = sin(dlat/2) * sin(dlat/2) + cos(radians(lat1)) \
+        * cos(radians(lat2)) * sin(dlon/2) * sin(dlon/2)
+    c = 2 * atan2(sqrt(a), sqrt(1-a))
+    d = radius * c
+    return d
+
+
+def disjointed(filename):
+    with open(filename, 'r') as f:
+        gpx = parse(f)
+    old_segments = gpx.tracks[0].segments
+    new_segments = fix_segments(old_segments, 0.1)
+    if len(old_segments) == len(new_segments):
+        return None
+    gpx.tracks[0].segments = new_segments
+    with open(filename, 'w') as f:
+        f.write(gpx.to_xml())
+    return filename
+
+
+def main():
+    files = glob("strava/*.gpx")
+    p = multiprocessing.Pool(multiprocessing.cpu_count())
+    results = p.map(disjointed, files)
+    bad = [r for r in results if r]
+    if len(bad) > 0:
+        if os.path.isfile("strava/cache.pkl"):
+            os.remove("strava/cache.pkl")
+        print(bad)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_rogue.py b/tests/test_rogue.py
new file mode 100644
index 0000000..b9d0b76
--- /dev/null
+++ b/tests/test_rogue.py
@@ -0,0 +1,136 @@
+from heatmap import rogue
+import unittest
+from unittest.mock import patch, MagicMock
+from gpxpy.gpx import GPXTrackPoint, GPXTrackSegment
+from typing import Any, List
+
+
+def fake_distance(p1, p2):
+    return abs(p2.latitude - p1.latitude)
+
+
+def generate_segment(lats):
+    points = [GPXTrackPoint(latitude=x) for x in lats]
+    return GPXTrackSegment(points=points)
+
+
+# This is dumb
+def equals(object1: Any, object2: Any, ignore: Any=None) -> bool:
+    """ Testing purposes only """
+
+    if not object1 and not object2:
+        return True
+
+    if not object1 or not object2:
+        print('Not obj2')
+        return False
+
+    if not object1.__class__ == object2.__class__:
+        print('Not obj1')
+        return False
+
+    if type(object1) == type(object2) == type([]):
+        if len(object1) != len(object2):
+            return False
+        for i in range(len(object1)):
+            if not equals(object1[i], object2[i]):
+                return False
+        return True
+
+    attributes: List[str] = []
+    for attr in dir(object1):
+        if not ignore or attr not in ignore:
+            if not hasattr(object1, '__call__') and not attr.startswith('_'):
+                if attr not in attributes:
+                    attributes.append(attr)
+
+    for attr in attributes:
+        attr1 = getattr(object1, attr)
+        attr2 = getattr(object2, attr)
+
+        if attr1 == attr2:
+            return True
+
+        if not attr1 and not attr2:
+            return True
+        if not attr1 or not attr2:
+            print(f'Object differs in attribute {attr} ({attr1} - {attr2})')
+            return False
+
+        if not equals(attr1, attr2):
+            print(f'Object differs in attribute {attr} ({attr1} - {attr2})')
+            return False
+
+    return True
+
+
+class TestBreaks(unittest.TestCase):
+
+    def test_break_segment_with_1_element(self):
+        lats = [1]
+        segment = generate_segment(lats)
+        result = rogue.break_segment(segment, [])
+        self.assertEqual(result, [segment])
+
+    def test_break_segment_with_2_elements_no_breaks(self):
+        lats = [1, 2]
+        segment = generate_segment(lats)
+        result = rogue.break_segment(segment, [])
+        self.assertEqual(result, [segment])
+
+    def test_break_segment_with_2_elements_1_break(self):
+        lats = [1, 3]
+        segment = generate_segment(lats)
+        result = rogue.break_segment(segment, [0])
+        correct = [generate_segment([1]), generate_segment([3])]
+        self.assertTrue(equals(result, correct))
+
+    @patch('heatmap.rogue.distance', MagicMock(side_effect=fake_distance))
+    def test_find_breaks_with_1_span(self):
+        lats = [1, 2, 3, 4, 5]
+        segment = generate_segment(lats)
+        result = rogue.find_breaks(segment, 1)
+        self.assertEqual(result, [])
+
+    @patch('heatmap.rogue.distance', MagicMock(side_effect=fake_distance))
+    def test_find_breaks_with_2_spans(self):
+        lats = [1, 2, 3, 5, 6, 7]
+        segment = generate_segment(lats)
+        result = rogue.find_breaks(segment, 1)
+        self.assertEqual(result, [2])
+
+    @patch('heatmap.rogue.distance', MagicMock(side_effect=fake_distance))
+    def test_find_breaks_with_3_spans(self):
+        lats = [1, 2, 3, 5, 6, 7, 9, 10, 11]
+        segment = generate_segment(lats)
+        result = rogue.find_breaks(segment, 1)
+        self.assertEqual(result, [2, 5])
+
+    @patch('heatmap.rogue.distance', MagicMock(side_effect=fake_distance))
+    def test_fix_segments_with_1_span(self):
+        lats = [1, 2, 3, 4, 5]
+        segment = generate_segment(lats)
+        result = rogue.fix_segments([segment], 1)
+        self.assertTrue(equals(result, [segment]))
+
+    @patch('heatmap.rogue.distance', MagicMock(side_effect=fake_distance))
+    def test_fix_segments_with_2_spans(self):
+        lats = [1, 2, 3, 5, 6, 7]
+        segment = generate_segment(lats)
+        result = rogue.fix_segments([segment], 1)
+        correct = [generate_segment([1, 2, 3]), generate_segment([5, 6, 7])]
+        self.assertTrue(equals(result, correct))
+
+    @patch('heatmap.rogue.distance', MagicMock(side_effect=fake_distance))
+    def test_fix_segments_with_3_spans(self):
+        lats = [1, 2, 3, 5, 6, 7, 9, 10, 11]
+        segment = generate_segment(lats)
+        result = rogue.fix_segments([segment], 1)
+        correct = [generate_segment([1, 2, 3]),
+                   generate_segment([5, 6, 7]),
+                   generate_segment([9, 10, 11])]
+        self.assertTrue(equals(result, correct))
+
+
+if __name__ == "__main__":
+    unittest.main()