Skip to content

Commit

Permalink
Upgrade to TensorFlow 2 (#40)
Browse files Browse the repository at this point in the history
Changes our uses of TensorFlow APIs to be compatible with TensorFlow 2, mostly automatically through the tf_upgrade_v2 command. This change has been validated to work with all capsules in the Capsule Zoo as well as our internally developed capsules at Aotu.ai.

This aims to make OpenVisionCapsules TensorFlow 2 compatible without any features disabled, but does not attempt to upgrade our use of old v1 APIs.

While this upgrade requires no changes to the API of vcap or vcap-utils, it is still a breaking change for any capsule that has to use TensorFlow directly.

Fixes #39
  • Loading branch information
velovix authored Jul 20, 2021
1 parent 5506c2c commit 77e6304
Show file tree
Hide file tree
Showing 9 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion vcap/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"scipy==1.4.1",
"scikit-learn==0.22.2",
"numpy>=1.16,<2",
"tensorflow-gpu==1.15.4",
"tensorflow~=2.5.0",
],
extras_require={
"tests": test_packages,
Expand Down
2 changes: 1 addition & 1 deletion vcap/vcap/device_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_all_devices() -> List[str]:
#
# TODO: Use tf.config.list_physical_devices in TF 2.1

with tf.Session():
with tf.compat.v1.Session():
all_devices = device_lib.list_local_devices()

# Get the device names and remove duplicates, just in case...
Expand Down
2 changes: 1 addition & 1 deletion vcap_utils/vcap_utils/backends/crowd_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CrowdDensityCounter(BaseTFBackend):

def __init__(self, model_bytes,
device: str=None,
session_config: tf.ConfigProto=None):
session_config: tf.compat.v1.ConfigProto=None):
"""
:param model_bytes: Model file data, likely a loaded *.pb file
:param device: The device to run the model on
Expand Down
2 changes: 1 addition & 1 deletion vcap_utils/vcap_utils/backends/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DepthPredictor(BaseTFBackend):

def __init__(self, model_bytes,
device: str=None,
session_config: tf.ConfigProto=None):
session_config: tf.compat.v1.ConfigProto=None):
"""
:param model_bytes: Model file data, likely a loaded *.pb file
:param device: The device to run the model on
Expand Down
10 changes: 5 additions & 5 deletions vcap_utils/vcap_utils/backends/load_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def parse_tf_model_bytes(model_bytes,
device: str = None,
session_config: tf.ConfigProto = None):
session_config: tf.compat.v1.ConfigProto = None):
"""
:param model_bytes: The bytes of the model to load
Expand All @@ -18,7 +18,7 @@ def parse_tf_model_bytes(model_bytes,
detection_graph = tf.Graph()
with detection_graph.as_default():
# Load a (frozen) Tensorflow model from memory
graph_def = tf.GraphDef()
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(model_bytes)

with tf.device(device):
Expand All @@ -29,16 +29,16 @@ def parse_tf_model_bytes(model_bytes,
name='')

if session_config is None:
session_config = tf.ConfigProto()
session_config = tf.compat.v1.ConfigProto()

if device is not None:
# allow_soft_placement lets us remap GPU only ops to GPU, and doesn't
# crash for non-gpu only ops (it will place those on CPU, instead)
session_config.allow_soft_placement = True

# Create a session for later use
persistent_sess = tf.Session(graph=detection_graph,
config=session_config)
persistent_sess = tf.compat.v1.Session(graph=detection_graph,
config=session_config)

return detection_graph, persistent_sess

Expand Down
2 changes: 1 addition & 1 deletion vcap_utils/vcap_utils/backends/openface_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class OpenFaceEncoder(BaseEncoderBackend, BaseTFBackend):

def __init__(self, model_bytes, model_name,
device: str = None,
session_config: tf.ConfigProto = None):
session_config: tf.compat.v1.ConfigProto = None):
"""
:param model_bytes: Model file bytes, a loaded *.pb file
:param model_name: The name of the model in order to load correct
Expand Down
2 changes: 1 addition & 1 deletion vcap_utils/vcap_utils/backends/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Segmenter(BaseTFBackend):

def __init__(self, model_bytes, metadata_bytes,
device: str = None,
session_config: tf.ConfigProto = None):
session_config: tf.compat.v1.ConfigProto = None):
"""
:param model_bytes: Model file data, likely a loaded *.pb file
:param metadata_bytes: The dataset metadata file data, likely named
Expand Down
4 changes: 2 additions & 2 deletions vcap_utils/vcap_utils/backends/tf_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class TFImageClassifier(BaseTFBackend):
def __init__(self, model_bytes, metadata_bytes, model_name,
device: str = None,
session_config: tf.ConfigProto = None):
session_config: tf.compat.v1.ConfigProto = None):
"""
:param model_bytes: Loaded model data, likely from a *.pb file
:param metadata_bytes: Loaded dataset metadata, likely from a file
Expand All @@ -41,7 +41,7 @@ def __init__(self, model_bytes, metadata_bytes, model_name,
# Create the input node to the graph, with preprocessing built-in
with self.graph.as_default():
# Create a new input node for images of various sizes
self.input_node = tf.placeholder(
self.input_node = tf.compat.v1.placeholder(
dtype=tf.float32,
shape=[None, self.config.img_size, self.config.img_size, 3])

Expand Down
2 changes: 1 addition & 1 deletion vcap_utils/vcap_utils/backends/tf_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TFObjectDetector(BaseTFBackend):
def __init__(self, model_bytes, metadata_bytes,
confidence_thresh=0.05,
device: str = None,
session_config: tf.ConfigProto = None):
session_config: tf.compat.v1.ConfigProto = None):
"""
:param model_bytes: Model file data, likely a loaded *.pb file
:param metadata_bytes: The dataset metadata file data, likely named
Expand Down

0 comments on commit 77e6304

Please sign in to comment.