Add spaces-tagged data types #37

Open
ryanjulian opened this issue May 15, 2019 · 3 comments

Comments

@ryanjulian
Member

This would add tagged data types (and no-copy constructors) for the concrete types represented by a space.

This would allow users to perform important akro operations without needing an explicit reference to the space (and save a lot of type checking). It would also get rid of a lot of helper/misc functions.

This idea is best illustrated by an example:

import akro

dspace = akro.Dict({'foo': akro.Discrete(3)})

d = {'foo': 1}  # a regular old dict
d_tagged = akro.tag(d, dspace)  # returns an akro.dict (inherits from dict) which stores dspace
d_flat = d_tagged.flatten()  # now I can flatten without any args
d_flat.to_tf_placeholder()  # sure, why not?
e = dspace.sample()  # returns an akro.dict by default
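
On the "no-copy constructors" part: one possible way to tag a dict without copying it is a read-only wrapper that keeps a reference to the original dict. This is only a minimal sketch; `TaggedMapping` and its `space` property are hypothetical names for illustration, not existing akro API, and a plain string stands in for the space.

```python
from collections.abc import Mapping


class TaggedMapping(Mapping):
    """Read-only view of a dict plus the space it belongs to (hypothetical name)."""

    def __init__(self, data, space):
        self._data = data    # reference to the caller's dict, not a copy
        self._space = space

    @property
    def space(self):
        return self._space

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


d = {'foo': 1}
d_tagged = TaggedMapping(d, space='dspace placeholder')  # any object can stand in for the space
d['foo'] = 2             # mutate the original dict
print(d_tagged['foo'])   # the view sees the change, because nothing was copied
```

The trade-off is that such a wrapper is a `Mapping` but not a `dict`, so `isinstance(x, dict)` checks fail, whereas a `dict` subclass built with `SubDict(data)` passes them but makes a shallow copy.
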
@ryanjulian
Member Author

I'm curious to hear @krzentner's feedback on this proposal.

@sud0nick
Contributor

  1. In the example you gave, what would d_tagged look like?
    1a. Does it contain both d and dspace as separate items in a dictionary?
    1b. If so, what would their keys be?
  2. When you say "# returns an akro.dict (inherits from dict) which stores dspace", are you referring to gym.spaces.Dict?

@ryanjulian
Member Author

ryanjulian commented May 17, 2019

By the way, this is just an idea. I didn't mean for it to be an assignment (necessarily).

I'm not sure I totally conveyed what I meant. I will add more examples.

class AkroDict(dict):
    """ This should probably actually inherit from types.MappingProxyType to avoid copying"""

    @property
    def space(self):
        return self._space

    def flatten(self):
        return self._space.flatten(self)

    # etc...

def tag(data, space):
    if not space.contains(data):
        raise ValueError('Cannot tag {} with space {}. The space must contain the tagged data.'.format(data, space))

    if isinstance(data, dict):
        tagged = AkroDict(data)
    elif isinstance(data, np.ndarray):
        tagged = AkroArray(data)
    elif isinstance(data, tuple):
        raise NotImplementedError  # etc... (remaining types omitted from this sketch)

    tagged._space = space

    return tagged
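
To make the sketch above concrete without depending on akro or numpy, here is a self-contained toy version. Everything in it is an assumption for illustration: `ToyDictSpace` is a hypothetical stand-in for a dict space of discrete values, `TaggedDict` plays the role of `AkroDict`, and one-hot flattening is just this toy's choice, not a statement about what akro does.

```python
class ToyDictSpace:
    """Hypothetical stand-in for a space: maps each key to its number of allowed values."""

    def __init__(self, sizes):
        self.sizes = sizes

    def contains(self, data):
        if not isinstance(data, dict) or data.keys() != self.sizes.keys():
            return False
        return all(0 <= data[k] < n for k, n in self.sizes.items())

    def flatten(self, data):
        # One-hot encode each discrete value and concatenate in sorted key order.
        flat = []
        for key in sorted(self.sizes):
            one_hot = [0] * self.sizes[key]
            one_hot[data[key]] = 1
            flat.extend(one_hot)
        return flat


class TaggedDict(dict):
    """A dict that remembers its space as an attribute, not as a key."""

    def __init__(self, data, space):
        super().__init__(data)   # shallow copy of the items
        self._space = space

    @property
    def space(self):
        return self._space

    def flatten(self):
        return self._space.flatten(self)


def tag(data, space):
    if not space.contains(data):
        raise ValueError('Cannot tag {} with space {}.'.format(data, space))
    return TaggedDict(data, space)


space = ToyDictSpace({'foo': 3})
d_tagged = tag({'foo': 1}, space)
print(d_tagged)            # {'foo': 1}
print(d_tagged.flatten())  # [0, 1, 0]
```

Note that the tagged object has exactly the same keys as the original dict; the space lives on the object as an attribute, so existing code that treats it as a plain dict keeps working.
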

then later

>>> obs = env.reset()  # obs is a dict
>>> print(obs) 
{'position': np.array([1; 2]), 'velocity': np.array([3; 4])}
>>> obs = akro.tag(obs, env.observation_space)  # obs is now an AkroDict
>>> print(obs['velocity'])  # still works fine, an AkroDict is a dict
np.array([3; 4])
>>> print(obs.flatten())  # flattens the values under each key. nifty. notice how the ';' (column marker) changes to ',' (row marker), i.e. they are now 1D instead of 2D
{'position': np.array([1, 2]), 'velocity': np.array([3, 4])}
>>> print(obs.flatten()['velocity'])  # yep, this works too.
np.array([3, 4])
>>> ph = obs.flatten().to_tf_placeholder(batch_dims=1)  # sure, why not
>>> print(ph)
{'position': tf.placeholder(shape=(None, 2)), 'velocity': tf.placeholder(shape=(None, 2))}
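
The `batch_dims=1` argument above implies a simple shape rule: prepend one unknown dimension per batch dimension to each flattened shape. A rough sketch of just that shape computation, with no TensorFlow involved; `batched_shape` is a hypothetical helper name, and the per-key shapes are taken from the example output above.

```python
def batched_shape(shape, batch_dims=1):
    """Prepend `batch_dims` unknown (None) dimensions to a per-sample shape."""
    return (None,) * batch_dims + tuple(shape)


flat_shapes = {'position': (2,), 'velocity': (2,)}  # per-key shapes after flattening
ph_shapes = {key: batched_shape(shape, batch_dims=1)
             for key, shape in flat_shapes.items()}
print(ph_shapes)  # {'position': (None, 2), 'velocity': (None, 2)}
```
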
