Skip to content

Commit

Permalink
class based api
Browse files Browse the repository at this point in the history
  • Loading branch information
getnamo committed May 18, 2017
1 parent cb0a92c commit 40b3084
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 5 deletions.
61 changes: 61 additions & 0 deletions Content/Scripts/TFPluginAPI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import sys

import tensorflow as tf
import unreal_engine as ue

class TFPluginAPI():
@classmethod
def getInstance(cls):
#This should return an instance of your class even if you subclassed it
return cls()

def __init__(self):
#class scoped variable for stopping
self.shouldstop = False
self.stored = {}

#expected api: setup your model for training
def setup(self):
#setup or load your model and pass it into stored

#Usually store session, graph, and model if using keras
#self.sess = tf.InteractiveSession()
#self.graph = tf.get_default_graph()
pass

#expected api: storedModel and session, json inputs
def runJsonInput(self, jsonInput):
#e.g. our json input could be a pixel array
#pixelarray = jsonInput['pixels']

#run input on your graph
#e.g. sess.run(model['y'], feed_dict)
# where y is your result graph and feed_dict is {x:[input]}

#...

#return a json you will parse e.g. a prediction
result = {}
result['prediction'] = 0

return result

#expected api: early stopping
def stop(self):
self.shouldstop = True

#expected api: no params forwarded for training? TBC
def train(self):
#train here

#...

#inside your training loop check if we should stop early
#if(this.shouldstop):
# break
pass

#required function to get our api
def getApi():
#return CLASSNAME.getInstance()
return TFPluginAPI.getInstance()
17 changes: 12 additions & 5 deletions Content/Scripts/TensorFlowComponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,15 @@ def begin_play(self):
if(self.uobject.VerbosePythonLog):
ue.log('BeginPlay, importing TF module: ' + self.uobject.TensorFlowModule)

self.tf = importlib.import_module(self.uobject.TensorFlowModule)
imp.reload(self.tf)
#import the module
self.tfModule = importlib.import_module(self.uobject.TensorFlowModule)
imp.reload(self.tfModule)

#tfc or the class instance holding the pluginAPI
self.tfapi = self.tfModule.getApi()

#init
self.tfapi.setup()

#train
if(self.uobject.ShouldTrainOnBeginPlay):
Expand All @@ -30,15 +37,15 @@ def begin_play(self):

def end_play(self):
self.ValidGameWorld = False
self.tf.stop()
self.tfapi.stop()

#tensor input
def tensorinput(self, args):
if(self.uobject.VerbosePythonLog):
ue.log(self.uobject.TensorFlowModule + ' input passed: ' + args)

#pass the raw json to the script to handle
resultJson = self.tf.runJsonInput(self.trained, json.loads(args))
resultJson = self.tfapi.runJsonInput(json.loads(args))

#pass prediction json back
self.uobject.OnResultsFunction(json.dumps(resultJson))
Expand All @@ -57,7 +64,7 @@ def trainBlocking(self):

#calculate the time it takes to train your network
start = time.time()
self.trained = self.tf.train()
self.trained = self.tfapi.train()
stop = time.time()

if self.trained is None:
Expand Down

0 comments on commit 40b3084

Please sign in to comment.