-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
73 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import sys | ||
|
||
import tensorflow as tf | ||
import unreal_engine as ue | ||
|
||
class TFPluginAPI(): | ||
@classmethod | ||
def getInstance(cls): | ||
#This should return an instance of your class even if you subclassed it | ||
return cls() | ||
|
||
def __init__(self): | ||
#class scoped variable for stopping | ||
self.shouldstop = False | ||
self.stored = {} | ||
|
||
#expected api: setup your model for training | ||
def setup(self): | ||
#setup or load your model and pass it into stored | ||
|
||
#Usually store session, graph, and model if using keras | ||
#self.sess = tf.InteractiveSession() | ||
#self.graph = tf.get_default_graph() | ||
pass | ||
|
||
#expected api: storedModel and session, json inputs | ||
def runJsonInput(self, jsonInput): | ||
#e.g. our json input could be a pixel array | ||
#pixelarray = jsonInput['pixels'] | ||
|
||
#run input on your graph | ||
#e.g. sess.run(model['y'], feed_dict) | ||
# where y is your result graph and feed_dict is {x:[input]} | ||
|
||
#... | ||
|
||
#return a json you will parse e.g. a prediction | ||
result = {} | ||
result['prediction'] = 0 | ||
|
||
return result | ||
|
||
#expected api: early stopping | ||
def stop(self): | ||
self.shouldstop = True | ||
|
||
#expected api: no params forwarded for training? TBC | ||
def train(self): | ||
#train here | ||
|
||
#... | ||
|
||
#inside your training loop check if we should stop early | ||
#if(this.shouldstop): | ||
# break | ||
pass | ||
|
||
#required function to get our api | ||
def getApi(): | ||
#return CLASSNAME.getInstance() | ||
return TFPluginAPI.getInstance() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters