[Refactor] Made CrossValTypes, HoldoutValTypes to have split functions directly #108

Open
wants to merge 12 commits into
base: master

Conversation

nabenabe0928
Contributor

I made these changes while keeping them as small as possible.

autoPyTorch/datasets/resampling_strategy.py (outdated)
class HoldoutValTypes(Enum):
    """The type of hold out validation (refer to CrossValTypes' doc-string)"""
    holdout_validation = partial(HoldoutValFuncs.holdout_validation)
    stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation)
Contributor Author

Major change: switched from IntEnum to Enum, and the members now hold the split functions directly.


    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
                 ) -> Tuple[np.ndarray, np.ndarray]:
        # Return the split; without `return` the call would always yield None
        return self.value(val_share=val_share, indices=indices, stratify=stratify)
Contributor Author

Now the function can be called directly, as in HoldoutValTypes.holdout_validation().
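The direct call described above can be sketched roughly as follows. This is a minimal, standard-library-only illustration, not the code in this PR: `_toy_holdout` is a hypothetical stand-in for `HoldoutValFuncs.holdout_validation`, plain lists replace `np.ndarray`, and `stratify` is omitted. The `member` fallback is also an addition: on newer Python versions a `partial` object inside an Enum body may need to be wrapped in `enum.member` to remain a member.

```python
from enum import Enum
from functools import partial
from typing import List, Sequence, Tuple

try:
    from enum import member  # available from Python 3.11
except ImportError:
    def member(obj):
        # Older Pythons: partial objects in an Enum body already become members
        return obj


def _toy_holdout(val_share: float, indices: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical splitter: the last `val_share` fraction is used for validation."""
    n_val = int(round(len(indices) * val_share))
    cut = len(indices) - n_val
    return list(indices[:cut]), list(indices[cut:])


class HoldoutValTypes(Enum):
    # The member's value is the split function itself (wrapped in partial)
    holdout_validation = member(partial(_toy_holdout))

    def __call__(self, val_share: float, indices: Sequence[int]) -> Tuple[List[int], List[int]]:
        # Delegate to the stored function so the member itself is callable
        return self.value(val_share=val_share, indices=indices)


train, val = HoldoutValTypes.holdout_validation(val_share=0.25, indices=list(range(8)))
# train == [0, 1, 2, 3, 4, 5], val == [6, 7]
```

Wrapping in `partial` (or `member`) matters because a bare function assigned in an Enum body is treated as a method rather than a member.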

autoPyTorch/datasets/resampling_strategy.py (outdated)
     def __call__(self,
                  num_splits: int,
                  indices: np.ndarray,
                  stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
         ...


-class HoldOutFunc(Protocol):
+class HoldoutValFunc(Protocol):
Contributor Author

Since we often use the name holdout_validator, I unified the naming to match it.
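For readers unfamiliar with callable protocols, here is a minimal sketch of how a `Protocol` such as `HoldoutValFunc` types a split function. The signature is adapted from the snippets in this PR, with plain sequences instead of `np.ndarray`; `tail_holdout` is a hypothetical splitter used only for illustration.

```python
from typing import List, Optional, Protocol, Sequence, Tuple


class HoldoutValFunc(Protocol):
    """Structural type: any callable with this signature qualifies."""

    def __call__(self, val_share: float, indices: Sequence[int],
                 stratify: Optional[Sequence[int]]) -> Tuple[List[int], List[int]]:
        ...


def tail_holdout(val_share: float, indices: Sequence[int],
                 stratify: Optional[Sequence[int]] = None) -> Tuple[List[int], List[int]]:
    # Hypothetical splitter: keeps the tail of `indices` for validation
    n_val = int(round(len(indices) * val_share))
    cut = len(indices) - n_val
    return list(indices[:cut]), list(indices[cut:])


# A static type checker accepts this assignment because the signatures are compatible
holdout_validator: HoldoutValFunc = tail_holdout
train, val = holdout_validator(val_share=0.5, indices=[3, 1, 4, 1], stratify=None)
# train == [3, 1], val == [4, 1]
```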

nabenabe0928 added the refactoring label (Improvement of readability and abstract codes) Feb 26, 2021
nabenabe0928 added this to In progress in Roadmap Feb 26, 2021
nabenabe0928 force-pushed the refactoring-base-dataset_splitting-functions_major-change branch from 62b326e to c7fd2d5 March 15, 2021 20:54
nabenabe0928 force-pushed the refactoring-base-dataset_splitting-functions_major-change branch from c7fd2d5 to a7e8a7f March 18, 2021 23:10
franchuterivera changed the base branch from refactor_development to development May 7, 2021 09:02
nabenabe0928 force-pushed the refactoring-base-dataset_splitting-functions_major-change branch 2 times, most recently from 1e82b21 to d313a48 May 10, 2021 03:30
Because the previous code defaulted to shuffle = True and shuffled the
indices before splitting, the test cases for CV and Holdout no longer matched.
More specifically, restoring the following reproduces the original outputs:
1. Bring back _get_indices in BaseDataset
2. Make the default value of self.shuffle in BaseDataset True
3. Pass shuffle = True to KFold instead of using ShuffleSplit
Note that KFold(shuffle=True) and ShuffleSplit() are not identical:
even with the same random_state, their results do not match.
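The difference between the two schemes can be illustrated with a pure-Python sketch. It is not scikit-learn's implementation: `kfold_shuffled` and `shuffle_split` are hypothetical helpers that only mimic the general behaviour (one shuffled partition into disjoint folds versus an independent random draw per split).

```python
import random
from typing import List, Tuple

Split = Tuple[List[int], List[int]]


def kfold_shuffled(n: int, n_splits: int, seed: int) -> List[Split]:
    """Shuffle once, then cut the permutation into disjoint validation folds."""
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    folds = [perm[i::n_splits] for i in range(n_splits)]
    return [
        ([i for j, fold in enumerate(folds) if j != k for i in fold], folds[k])
        for k in range(n_splits)
    ]


def shuffle_split(n: int, n_splits: int, val_share: float, seed: int) -> List[Split]:
    """Draw a fresh permutation per split, so validation sets may overlap."""
    rng = random.Random(seed)
    n_val = int(round(n * val_share))
    splits = []
    for _ in range(n_splits):
        perm = list(range(n))
        rng.shuffle(perm)
        splits.append((perm[n_val:], perm[:n_val]))
    return splits


kfold = kfold_shuffled(n=10, n_splits=5, seed=0)
shuffled = shuffle_split(n=10, n_splits=5, val_share=0.2, seed=0)

# K-fold: every index lands in exactly one validation fold
kfold_val = sorted(i for _, val in kfold for i in val)
# Shuffle-split: same total count, but indices can repeat or be missing
shuffled_val = [i for _, val in shuffled for i in val]
```

Since one scheme partitions the data and the other resamples it, sharing a seed cannot make the outputs agree in general.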
nabenabe0928 force-pushed the refactoring-base-dataset_splitting-functions_major-change branch from af8059b to 6ef981d May 19, 2021 05:13
indices: np.ndarray,
**kwargs: Any
) -> List[Tuple[np.ndarray, np.ndarray]]:
Additionally, HoldoutValTypes.<function> can be called directly.
Contributor

Can you add an example of how to use it directly?



class CrossValFuncs():
    # (shuffle, is_stratify) -> split_fn
Contributor

Can we also have documentation here, similar to HoldoutFuncs?

Contributor

ravinkohli left a comment

Hey, thanks a lot for this PR. I have left a few suggestions. Also, could you state the reason for making this PR? What issues were there in the previous implementation, and how does this PR solve them?
