Add language keywords for syntax highlighting #36

Open · wants to merge 1 commit into base: main
README.md: 40 changes (20 additions, 20 deletions)
@@ -38,7 +38,7 @@ For JAX/Flax users, take a look at a simple train function [here](https://github.

### Initialization
The class's `__init__` method defines the cache and has several functional parameters `*_fn` for easy adjustment of model behaviors. Alternatively, you can also subclass GradCache.
-```
+```python
grad_cache.GradCache(
models: List[nn.Module],
chunk_sizes: Union[int, List[int]],
@@ -66,7 +66,7 @@ grad_cache.GradCache(
### Cache Gradient Step
To run a cached gradient computation step, call the `cache_step` function,

-```
+```python
cache_step(
*model_inputs,
no_sync_except_last: bool = False,
@@ -101,7 +101,7 @@ To run with them, `split_input_fn` should be specified during cache initialization.
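As an illustrative sketch only (the surrounding lines are collapsed in this diff, so the exact contract shown here is an assumption): a `split_input_fn` for tokenizer-style dict inputs would take the full model input plus a chunk size and return a list of chunks,

```python
# Hypothetical split_input_fn: chunk a dict of equally sized tensors
# (e.g. a Huggingface tokenizer output) along the batch dimension.
def split_dict_input(model_input, chunk_size):
    keys = list(model_input.keys())
    chunked = [t.split(chunk_size, dim=0) for t in model_input.values()]
    return [dict(zip(keys, vals)) for vals in zip(*chunked)]
```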
## Example Usage with Huggingface Transformers
### Learning a Bi-encoder
Say we want to learn an embedding space of labels and text. Consider the following four pairs. (In practice, you will have many more and much longer text entries.)
-```
+```python
labels = ['fruit', 'meat', 'school', 'company']
texts = [
'this is an apple',
@@ -112,14 +112,14 @@ texts = [
```

Initialize our encoder models,
-```
+```python
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
```
Initialize the GradCache object,
-```
+```python
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss

@@ -134,24 +134,24 @@ gc = GradCache(
Here we use the **get_rep_fn** argument to specify a function that takes the generic Huggingface model output and returns the actual representation tensor.
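For example, a minimal sketch of such a function for a BERT-style model (the pooling choice is an assumption, not necessarily the one used above):

```python
# Hypothetical get_rep_fn: pull one representation vector per example out of
# the Huggingface output, here the pooled [CLS] representation.
def get_rep_fn(model_output):
    return model_output.pooler_output
```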

Create model input,
-```
+```python
xx = tokenizer(labels, return_tensors='pt', padding=True)
yy = tokenizer(texts, return_tensors='pt', padding=True)
```
Run a cache step,
-```
+```python
gc(xx, yy, reduction='mean')
```
Here we pass `reduction='mean'` as part of **loss_kwargs** to control the loss behavior. With a defined `optimizer`, the full gradient update can be done as,
-```
+```python
optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()
```
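The `optimizer` itself is not defined in this snippet; one possible setup (an illustrative assumption, including the hyperparameters) is to optimize both encoders jointly:

```python
import torch

# Jointly optimize both encoders' parameters.
optimizer = torch.optim.AdamW(
    list(encoder1.parameters()) + list(encoder2.parameters()), lr=2e-5)
```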

### Use Tied Encoder?
This is naturally handled by the (magic of) dynamic graph. You pass shallow copies of the same encoder model to the GradCache init method.
-```
+```python
tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
models=[tied_encoder, tied_encoder],
@@ -163,19 +163,19 @@ gc = GradCache(
Under the hood, distinct hooks will be registered to ensure correct gradient computation.
### Distributed Training with Multiple GPUs?
We expect cross-process communication of representations to be handled by the `loss_fn`.
-```
+```python
from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
```
Properly wrap the encoder models for gradient reduction,
-```
+```python
from torch.nn.parallel import DistributedDataParallel

encoder1_ddp = DistributedDataParallel(
    encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
    encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
```
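The `local_rank` used above is assumed to come from the process launcher. For example, with `torchrun` it can be read from the environment after initializing the process group (a sketch, not part of the original snippet):

```python
import os
import torch
import torch.distributed as dist

# torchrun sets LOCAL_RANK for every process it spawns.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
```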
You can then initialize the cache using the distributed loss and the DDP models,
-```
+```python
gc = GradCache(
models=[encoder1_ddp, encoder2_ddp],
chunk_sizes=2,
@@ -184,23 +184,23 @@ gc = GradCache(
)
```
Run a cache step,
-```
+```python
gc(xx, yy, no_sync_except_last=True, reduction='mean')
```
Set `no_sync_except_last=True` to avoid unnecessary gradient reduction.

## Functional Approach
### Decorators
If you are developing a new project, we recommend also checking out the decorators we provide for creating higher-order functions for the cache.
-```
+```python
grad_cache.functional.cached(func: Callable[..., Tensor])
```
A decorator that turns a model call function into a cache-compatible version.

**func** - A function that calls the model and returns the representation tensor.

**Return** - A function that returns 1) representation leaf tensors for cache construction and 2) a closure function for the 2nd forward and the cached backward. Call 2) with 1) as its argument after calling backward on the loss Tensor.
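As a small, hypothetical sketch of that calling convention (the `encode` function, its pooling, and the surrounding names are illustrative, not part of the API):

```python
from grad_cache.functional import cached

@cached
def encode(model, model_input):
    return model(**model_input).last_hidden_state[:, 0]  # assumed [CLS] pooling

reps, closure = encode(encoder, model_input)  # 1) leaf reps, 2) closure
loss = loss_fn(reps, other_reps)              # any loss over the cached reps
loss.backward()                               # backward w.r.t. the cached reps
closure(reps)                                 # 2nd forward + cached backward
```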
-```
+```python
grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
```
A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor along the 0th dimension. This can come in handy for dealing with representation tensors produced by multiple cached forward passes.
@@ -209,7 +209,7 @@ A decorator that concatenates positional and keyword arguments of type List[Tens

**Return** - Decorated loss function for cached results.

-```
+```python
grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
```
A decorator that all-gathers positional and keyword arguments of type Tensor and concatenates them along `axis`. Intended to be used to create a distributed contrastive learning loss.
@@ -219,7 +219,7 @@ A decorator that all-gather positional and keyword arguments of type Tensor and
**Return** - Decorated loss function for distributed training.
### Usage
The functional decorators are particularly useful if your data loader emits small batches from which you can construct the big batch. Say you also want to do automatic mixed precision; we first define the model call function and the loss function,
-```
+```python
from grad_cache.functional import cached, cat_input_tensor

import torch
@@ -240,7 +240,7 @@ def contrastive_loss(x, y):
```
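The bodies of the model call function and the loss are collapsed in the diff above. As an illustrative sketch (the pooling choice, the use of `autocast`, and the exact loss form are assumptions), they could look like:

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

@cached
@autocast()
def call_model(model, model_input):
    # Return one representation vector per example; [CLS] pooling is assumed.
    return model(**model_input).last_hidden_state[:, 0]

@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
    # Score every x against every y; the matching entries are the targets.
    target = torch.arange(0, y.size(0), y.size(0) // x.size(0), device=x.device)
    scores = torch.matmul(x, y.transpose(0, 1))
    return F.cross_entropy(scores, target)
```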
Say you have a DataLoader `loader` emitting small batches of tuples `(xx, yy)` of size (M * N), and you want to train by aggregating 16 small batches to get a batch of (16M * 16N),

-```
+```python
cache_x = []
cache_y = []
closures_x = []
@@ -278,13 +278,13 @@ for step, sub_batch in enumerate(loader):
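The loop body is likewise collapsed in the diff. A rough sketch of such an accumulation loop, assuming the `call_model` and `contrastive_loss` sketches above, a `bert` encoder, an `optimizer`, and a gradient scaler for mixed precision (a `closures_y` list is assumed alongside the lists shown above):

```python
scaler = torch.cuda.amp.GradScaler()

for step, sub_batch in enumerate(loader):
    xx, yy = sub_batch
    rx, cx = call_model(bert, xx)  # leaf representations + backward closure
    ry, cy = call_model(bert, yy)

    cache_x.append(rx)
    cache_y.append(ry)
    closures_x.append(cx)
    closures_y.append(cy)

    if (step + 1) % 16 == 0:
        # Big-batch loss over all cached representations.
        loss = contrastive_loss(cache_x, cache_y)
        scaler.scale(loss).backward()

        # Replay each sub-batch forward and run the cached backward.
        for f, r in zip(closures_x, cache_x):
            f(r)
        for f, r in zip(closures_y, cache_y):
            f(r)

        cache_x, cache_y = [], []
        closures_x, closures_y = [], []

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```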
Running distributed multi-process training requires 1) (all-)gathering representations across devices and 2) (all-)reducing gradients across devices. Both steps happen **outside** the cache-decorated functions.

The latter is easy to achieve by wrapping encoders, e.g. a `bert`, in `DistributedDataParallel`.
-```
+```python
bert = DistributedDataParallel(
bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
```

The former requires extra distributed ops in the loss function, which should be added according to the original loss definition. For example,
-```
+```python
from torch import distributed as dist
from grad_cache.functional import cat_input_tensor, gather_input_tensor
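import torch
import torch.nn.functional as F

# The remainder of this example is collapsed in the diff. As an illustrative
# sketch (the decorator order and the exact loss form are assumptions), a loss
# that first concatenates the locally cached lists and then all-gathers the
# resulting tensors across processes could look like:
@cat_input_tensor
@gather_input_tensor
def dist_contrastive_loss(x, y):
    target = torch.arange(x.size(0), device=x.device)  # assumes x and y align one-to-one
    scores = torch.matmul(x, y.transpose(0, 1))
    return F.cross_entropy(scores, target)
```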
