Skip to content

Commit

Permalink
added gif creator
Browse files Browse the repository at this point in the history
  • Loading branch information
tawnkramer committed Oct 22, 2018
1 parent 4e67ce1 commit 47c8876
Showing 1 changed file with 39 additions and 12 deletions.
51 changes: 39 additions & 12 deletions examples/supervised_learning/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,39 @@
from donkeycar.utils import linear_unbin
import conf

class GifCreator(object):

def __init__(self, filename):
import imageio
self.filename = filename
self.images = []
self.every_nth_frame = 4
self.i_frame = 0

def add_image(self, image):
self.i_frame += 1
if self.i_frame % self.every_nth_frame == 0:
self.images.append(image)

def close(self):
import imageio
if len(self.images) > 0:
print('writing movie', self.filename)
imageio.mimsave(self.filename, self.images)


class DonkeySimMsgHandler(IMesgHandler):

STEERING = 0
THROTTLE = 1

def __init__(self, model, constant_throttle):
def __init__(self, model, constant_throttle, movie_handler=None):
self.model = model
self.constant_throttle = constant_throttle
self.sock = None
self.timer = FPSTimer()
self.image_folder = None
self.movie_handler = movie_handler
self.fns = {'telemetry' : self.on_telemetry}

def on_connect(self, socketHandler):
Expand All @@ -62,11 +84,9 @@ def on_telemetry(self, data):
image_array = np.asarray(image)
self.predict(image_array)

# maybe save frame
if self.image_folder is not None:
timestamp = datetime.utcnow().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
image_filename = os.path.join(self.image_folder, timestamp)
image.save('{}.jpg'.format(image_filename))
# maybe write movie
if self.movie_handler is not None:
self.movie_handler.add_image(image_array)


def predict(self, image_array):
Expand Down Expand Up @@ -112,20 +132,25 @@ def send_control(self, steer, throttle):
self.sock.queue_message(msg)


def on_close(self):
pass
def on_disconnect(self):
if self.movie_handler:
self.movie_handler.close()



def go(filename, address, constant_throttle):
def go(filename, address, constant_throttle, gif):

model = load_model(filename)

#In this mode, looks like we have to compile it
model.compile("sgd", "mse")

movie_handler = None

if gif != "none":
movie_handler = GifCreator(gif)

#setup the server
handler = DonkeySimMsgHandler(model, constant_throttle)
handler = DonkeySimMsgHandler(model, constant_throttle, movie_handler)
server = SimServer(address, handler)

try:
Expand All @@ -140,7 +165,9 @@ def go(filename, address, constant_throttle):
parser = argparse.ArgumentParser(description='prediction server')
parser.add_argument('--model', type=str, help='model filename')
parser.add_argument('--constant_throttle', type=float, default=0.0, help='apply constant throttle')
parser.add_argument('--gif', type=str, default="none", help='make animated gif of evaluation')

args = parser.parse_args()

address = ('0.0.0.0', 9091)
go(args.model, address, args.constant_throttle)
go(args.model, address, args.constant_throttle, args.gif)

0 comments on commit 47c8876

Please sign in to comment.