
What is the difference between the loss from packing and from direct batching? #3

Open
BitVoyage opened this issue Feb 18, 2024 · 3 comments

@BitVoyage

The paper states that the packing loss differs from the direct-batching loss, based on this formula:

[screenshot of the loss formula (equation (2)) from the paper]

That is, the loss is computed at the sample level in two steps: first averaged over the target tokens within each sample, then averaged across the samples in the batch.
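For concreteness, here is that two-step average written out (my own notation, not necessarily the paper's): $K$ is the number of samples in the batch, $N_i$ the number of target tokens in sample $i$, and $\ell_{ij}$ the loss on the $j$-th target token of sample $i$:

$$
\mathcal{L}_{\text{sample}} = \frac{1}{K}\sum_{i=1}^{K}\frac{1}{N_i}\sum_{j=1}^{N_i}\ell_{ij}
$$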

In my understanding, SFT training usually computes the final loss at the token level, i.e. "sum of target-token losses / total number of target tokens", not at the sample level.
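In the same notation as above, that token-level loss would be:

$$
\mathcal{L}_{\text{token}} = \frac{\sum_{i=1}^{K}\sum_{j=1}^{N_i}\ell_{ij}}{\sum_{i=1}^{K}N_i}
$$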

I looked at your implementation. In modeling_llama.py, the direct-batching path computes the loss by flattening batch*seq into a single sequence, so the loss is computed directly at the token level, not at the sample level (i.e. not first averaged within each sequence and then across the batch).

[screenshot of the loss computation in modeling_llama.py]
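For reference, a minimal sketch of the flattened, token-level cross-entropy I am describing (standard HF-style label shift; this is my paraphrase, not the repo's exact code):

```python
import torch.nn.functional as F

def flattened_token_level_loss(logits, labels, ignore_index=-100):
    # logits: (batch, seq, vocab); labels: (batch, seq) with -100 on non-target tokens
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # Flatten batch*seq into one long sequence and take a single mean over all
    # target tokens: "sum of target-token losses / number of target tokens".
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )
```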

Two questions for discussion:

  1. In SFT, should the final averaging of the loss be done at the token level or at the sample level?
  2. If it is done at the token level, I believe packing and non-packing are equivalent (see the toy example below).
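As a toy illustration of point 2 (the numbers are my own): take two samples, one with 2 target tokens whose per-token losses are 4.0 and 2.0, and one with 6 target tokens whose per-token losses are all 1.0. Then

$$
\mathcal{L}_{\text{token}} = \frac{4+2+6\cdot 1}{8} = 1.5,\qquad
\mathcal{L}_{\text{sample}} = \frac{1}{2}\left(\frac{4+2}{2} + \frac{6\cdot 1}{6}\right) = 2.0
$$

The token-level value is the same whether the two samples are packed into one sequence or batched separately (given the same set of samples per step), whereas the sample-level value gives the short sample more weight per token.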
@bys0318
Member

bys0318 commented Feb 20, 2024

Good question! In SFT, the loss is usually computed with a token-level mean within each sample and a sequence-level mean across samples, i.e. exactly the computation in equation (2). If you instead take a token-level mean across different samples, samples with more target tokens receive more weight (they are effectively upsampled), which introduces an imbalance between samples. If you do want the overall loss computed as "sum of target-token losses / total number of target tokens" as you describe, you only need to replace the per-sample mean in the code (which yields a token-level mean loss) with a sum (which yields the total target-token loss of that sample).
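To make the two reductions being contrasted concrete, here is a schematic side by side (my own sketch, not the repository's code):

```python
import torch

def sample_level_loss(per_token_loss: torch.Tensor, target_mask: torch.Tensor) -> torch.Tensor:
    # per_token_loss, target_mask: (batch, seq); mask is 1.0 on target tokens, 0.0 elsewhere.
    per_sample_mean = (per_token_loss * target_mask).sum(dim=1) / target_mask.sum(dim=1)
    # Equation (2): mean over target tokens within each sample, then mean across samples.
    return per_sample_mean.mean()

def token_level_loss(per_token_loss: torch.Tensor, target_mask: torch.Tensor) -> torch.Tensor:
    # "Sum of target-token losses / total number of target tokens" across the whole batch.
    return (per_token_loss * target_mask).sum() / target_mask.sum()
```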

@BitVoyage
Author

Thanks for the reply. My confusion is that I could not find the "mean within each sample, then mean across samples" computation in the code. From your README, llama uses sorted batching, and in modeling_llama.py the loss is not averaged: it is flattened into a sequence of token losses and then passed to transformers.Trainer for training. Is it transformers.Trainer that performs the two means?

@bys0318
Member

bys0318 commented Feb 20, 2024

When training with packing and loss weighting, self.pack_loss is set to True; please see the code under if self.pack_loss:. There, we first multiply the loss on each target token within a sample by its weight and take the sum; the losses of the different samples across GPUs are then averaged (mean) inside transformers.Trainer.
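A rough, self-contained sketch of the per-sample reduction described above (the function name and variables are illustrative, not the repository's):

```python
import torch

def packed_sample_loss(per_token_loss: torch.Tensor, loss_weights: torch.Tensor) -> torch.Tensor:
    """Weighted sum of target-token losses within one packed sample.

    per_token_loss: (num_target_tokens,) loss on each target token of the sample
    loss_weights:   (num_target_tokens,) precomputed loss weights for those tokens
    """
    # Weighted sum inside the sample; transformers.Trainer then averages the
    # resulting per-sample losses across samples / GPUs, which recovers the
    # sample-level mean of equation (2).
    return (per_token_loss * loss_weights).sum()
```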

bys0318 added the question label (Further information is requested) on Feb 20, 2024