Fun solution to loop replacement (makemore_part4_backprop.ipynb) #46

Open
jchwenger opened this issue Mar 16, 2024 · 1 comment

@jchwenger

Hi Andrej, hi everyone,

First of all, let me add my voice to the chorus: such awesome lectures, very grateful for them, I recommend them to everyone around me whenever I have the opportunity!

At one point in the backprop lecture, you mention that there might be a slicker way to update the last gradient tensor, dC, than the Python loop you used. This tickled my curiosity, so I tinkered, and here's the solution I came up with; maybe others have found even better ways! (Although, arguably, if you're not into Torch nerdiness, the time and peace of mind spent basking in advanced indexing might not be a great trade-off against the slow but straightforward loop! : >)

So, instead of:

dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
  for j in range(Xb.shape[1]):
    ix = Xb[k,j]
    dC[ix] += demb[k,j]
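
(For reference, the shapes in play here, assuming the lecture's defaults of batch_size 32, block_size 3, 10-dimensional embeddings and a 27-character vocabulary:)

# Xb:   (32, 3)      integer indices into the rows of C
# demb: (32, 3, 10)  gradient of the loss w.r.t. the embedded batch
# C:    (27, 10)     embedding table, so dC has the same (27, 10) shape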

It is possible to do:

# arange          -> unsqueeze  -> tile         -> flatten
# [0, 1, ..., 31] -> [[0],      -> [[0,0,0],    -> [0,0,0,1,1,1,...,31,31,31] # batch_size * block_size elements
#                     [1],          [1,1,1],
#                     ...           ...
#                     [31]]         [31,31,31]]
rows_xi = torch.tile(torch.arange(0, Xb.shape[0]).unsqueeze(1), (1, Xb.shape[1])).flatten()

# [0,1,2] -> [0,1,2,0,1,2,...,0,1,2] # tiled batch_size times
cols_xi = torch.tile(torch.arange(0, Xb.shape[1]), (Xb.shape[0],))

emb_xi = Xb[rows_xi, cols_xi] # batch_size * block_size indices into the rows of C

dC1 = torch.zeros_like(C)

dC1.index_put_((emb_xi,), demb[rows_xi, cols_xi], accumulate=True)

A torch.allclose(dC1, dC) yields True on my end.
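
(Side note: torch.Tensor.index_add_ appears to do the same accumulate-by-index job in a single call; a minimal sketch, reusing the emb_xi, rows_xi and cols_xi from above, with dC2 just an illustrative name:)

dC2 = torch.zeros_like(C)
# index_add_(dim, index, source): accumulate each row of source into dC2 at the given row index
dC2.index_add_(0, emb_xi, demb[rows_xi, cols_xi])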

I'm indebted to the all-answering @ptrblck for the .index_put_(..., accumulate=True) reference!

Have a great day!

@junqi-lu

Thanks to ChatGPT, here is an even faster way to replace the loop, using scatter_add_. A quick comparison of all three versions:

import time

loop_times = 1_000

# baseline: the explicit double loop
start = time.time()
for _ in range(loop_times):
    dC = torch.zeros_like(C)
    for i in range(demb.shape[0]):
        for j in range(demb.shape[1]):
            dC[Xb[i, j]] += demb[i, j]
print(time.time() - start)  # 0.7680590152740479

# the advanced-indexing + index_put_ version from above
start = time.time()
for _ in range(loop_times):
    rows_xi = torch.tile(torch.arange(0, Xb.shape[0]).unsqueeze(1), (1, Xb.shape[1])).flatten()
    cols_xi = torch.tile(torch.arange(0, Xb.shape[1]), (Xb.shape[0],))
    emb_xi = Xb[rows_xi, cols_xi]
    dC1 = torch.zeros_like(C)
    dC1.index_put_((emb_xi,), demb[rows_xi, cols_xi], accumulate=True)
print(time.time() - start)  # 0.022248029708862305

# scatter_add_: flatten the batch and block dimensions, then scatter-add along dim 0
start = time.time()
for _ in range(loop_times):
    dC = torch.zeros_like(C)
    Xb_flat = Xb.view(-1)                    # (batch_size * block_size,)
    demb_flat = demb.view(-1, demb.size(2))  # (batch_size * block_size, emb_dim)
    # row Xb_flat[i] of dC accumulates demb_flat[i]
    dC.scatter_add_(0, Xb_flat.unsqueeze(1).expand_as(demb_flat), demb_flat)
print(time.time() - start)  # 0.009483575820922852
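
(For completeness, the same scatter can also be written as a one-hot matrix multiply, in the spirit of the lecture's one-hot trick; a sketch, assuming the same Xb, demb and C, with dC3 just an illustrative name:)

import torch.nn.functional as F

# each row of the one-hot matrix selects one row of C; its transpose routes
# each flattened demb row back into the right row of the gradient
onehot = F.one_hot(Xb.view(-1), num_classes=C.shape[0]).float()  # (N, vocab_size)
dC3 = onehot.T @ demb.view(-1, demb.size(2))                     # (vocab_size, emb_dim)

This is likely slower than scatter_add_ for large vocabularies, since it materializes the full (N, vocab_size) one-hot matrix, but it makes the "accumulate gradients by index" picture explicit.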
