Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal TF2 support #1026

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- EvalCallback now also works for recurrent policies (@mily20001)
- Add minimal support for TF2 using ``tensorflow.compat.v1`` while keeping support for TF1

Bug Fixes:
^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines import logger
from stable_baselines.common import explained_variance, tf_util, ActorCriticRLModel, SetVerbosity, TensorboardWriter
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/acer/acer_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
from gym.spaces import Discrete, Box
from collections import deque

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/acktr/acktr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import warnings

import tensorflow as tf
import tensorflow.compat.v1 as tf
from gym.spaces import Box, Discrete

from stable_baselines import logger
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/acktr/kfac.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from functools import reduce

import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np

from stable_baselines.acktr.kfac_utils import detect_min_val, factor_reshape, gmatmul
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/acktr/kfac_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf


def gmatmul(tensor_a, tensor_b, transpose_a=False, transpose_b=False, reduce_dim=None):
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import gym
import cloudpickle
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines.common.misc_util import set_global_seeds
from stable_baselines.common.save_util import data_to_json, json_to_data, params_to_bytes, bytes_to_params
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/distributions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import math_ops
from gym import spaces

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/input.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete


Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/misc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf


def zipsame(*seqs):
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/mpi_adam.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np
import mpi4py

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/mpi_running_mean_std.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import mpi4py
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np

import stable_baselines.common.tf_util as tf_util
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
from gym.spaces import Discrete

from stable_baselines.common.tf_util import batch_to_seq, seq_to_batch
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/tf_layers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf


def ortho_init(scale=1.0):
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/common/tf_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Set

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf


def is_image(tensor):
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
import tensorflow.contrib as tc
from mpi4py import MPI

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/ddpg/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os

import gym
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np
from mpi4py import MPI

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/ddpg/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf
from gym.spaces import Box

from stable_baselines.common.policies import BasePolicy, nature_cnn, register_policy
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/deepq/build_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
Q' is set to Q once every 10000 updates training steps.

"""
import tensorflow as tf
import tensorflow.compat.v1 as tf
from gym.spaces import MultiDiscrete

from stable_baselines.common import tf_util
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/deepq/dqn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial

import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np
import gym

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/deepq/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf
import tensorflow.contrib.layers as tf_layers
import numpy as np
from gym.spaces import Discrete
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/gail/adversary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
I follow the architecture from the official repository
"""
import gym
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np

from stable_baselines.common.mpi_running_mean_std import RunningMeanStd
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import defaultdict
from typing import Optional

import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat
Expand Down Expand Up @@ -715,7 +715,7 @@ def read_tb(path):
import numpy as np
from glob import glob
# from collections import defaultdict
import tensorflow as tf
import tensorflow.compat.v1 as tf
if os.path.isdir(path):
fnames = glob(os.path.join(path, "events.*"))
elif os.path.basename(path).startswith("events."):
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/ppo1/pposgd_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
from mpi4py import MPI

from stable_baselines.common import Dataset, explained_variance, fmt_row, zipsame, ActorCriticRLModel, SetVerbosity, \
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/ppo2/ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines import logger
from stable_baselines.common import explained_variance, ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/sac/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np
from gym.spaces import Box

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines.common import tf_util, OffPolicyRLModel, SetVerbosity, TensorboardWriter
from stable_baselines.common.vec_env import VecEnv
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/td3/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf
from gym.spaces import Box

from stable_baselines.common.policies import BasePolicy, nature_cnn, register_policy
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines import logger
from stable_baselines.common import tf_util, OffPolicyRLModel, SetVerbosity, TensorboardWriter
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/trpo_mpi/trpo_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import gym
from mpi4py import MPI
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np

import stable_baselines.common.tf_util as tf_util
Expand Down
2 changes: 1 addition & 1 deletion tests/test_a2c_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines.common.tf_layers import conv
from stable_baselines.common.input import observation_input
Expand Down
2 changes: 1 addition & 1 deletion tests/test_custom_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gym
import pytest
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines import A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO, SAC, DDPG
from stable_baselines.common.policies import FeedForwardPolicy
Expand Down
2 changes: 1 addition & 1 deletion tests/test_distri.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

import stable_baselines.common.tf_util as tf_util
from stable_baselines.common.distributions import DiagGaussianProbabilityDistributionType,\
Expand Down
2 changes: 1 addition & 1 deletion tests/test_math_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy as np
from gym.spaces.box import Box

Expand Down
2 changes: 1 addition & 1 deletion tests/test_tf_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# tests for tf_util
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf

from stable_baselines.common.tf_util import function, initialize, single_threaded_session, is_image

Expand Down