Adds support for 4bit (nf4) and 8bit bitsandbytes quantization (3/3) #151

Open · wants to merge 15 commits into base: main
Changes from 1 commit
perf(fp16): reduce expected extra required iterations to 1.
Start the search with the minimal clipping value found through testing (2^16 - 3*32). This value was sufficient for all tested inputs. Further analysis is still required to guarantee that it will be sufficient in all cases.
Rypo committed Dec 17, 2024
commit 6ce30f2c7fcf95dd6514b5e7062b27b2ac644d47
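For context on the arithmetic in the commit message (an illustration, not part of the diff): fp16's largest representable value is 2^16 - 32 = 65504, and representable fp16 values above 2^15 are spaced 32 apart, so 2^16 - 3*32 sits two representable steps below the maximum.

```python
# Illustration only (not from the PR): check the fp16 spacing arithmetic
# behind the 2**16 - 3*32 starting clip value.
import torch

fp16_max = torch.finfo(torch.float16).max  # 65504.0 == 2**16 - 32
below_max = torch.nextafter(
    torch.tensor(fp16_max, dtype=torch.float16),
    torch.tensor(0.0, dtype=torch.float16),
).item()
step = fp16_max - below_max                # 32.0: spacing of fp16 values above 2**15
clip_val = 2**16 - 3*32                    # 65440 == fp16_max minus two steps
assert clip_val == fp16_max - 2 * step
```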
OmniGen/pipeline.py (4 changes: 3 additions & 1 deletion)
```diff
@@ -318,7 +318,9 @@ def __call__(
 
         scheduler = OmniGenScheduler(num_steps=num_inference_steps)
         if latents.dtype == torch.float16:
-            # fp16 overflows at ±2^16-32, but the actual clamp value may have to be lower to maintain decoder layer stability
+            # Continue to monitor. If _clip_val never changes, can remove scheduler autoset func and just hardcode clip val here.
+            #self.model.llm.set_clip_val(2**16-32 - 2*32) # hardcode clip val
+            # dry run the inputs, adjusting the clip bounds as necessary
             scheduler._fp16_clip_autoset(self.model.llm, latents, func, model_kwargs)
         samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
         samples = samples.chunk((1+num_cfg), dim=0)[0]
```
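The diff relies on `set_clip_val` / `_clip_val` on the LLM, whose implementation lives elsewhere in this PR. As a rough mental model only (the PR's actual clamp placement inside OmniGen's decoder layers may differ), the mechanism amounts to a symmetric clamp on fp16 activations:

```python
# Rough mental model only; NOT the PR's implementation. Assumes set_clip_val
# stores a symmetric bound that decoder layers apply to their hidden states.
import torch

class ClipValMixin:
    _clip_val = None  # symmetric clamp bound; None disables clamping

    def set_clip_val(self, clip_val):
        self._clip_val = clip_val

    def _clip(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # clamp fp16 activations into (-clip_val, +clip_val) to avoid inf/nan
        if self._clip_val is not None and hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clamp_(-self._clip_val, self._clip_val)
        return hidden_states
```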
OmniGen/scheduler.py (31 changes: 22 additions & 9 deletions)
```diff
@@ -128,23 +128,36 @@ def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
     @torch.no_grad()
     def _fp16_clip_autoset(self, model_llm, z, func, model_kwargs):
         '''Recursively search for a minimal clipping value for fp16 stability'''
+        fp16_max_repr = torch.finfo(torch.float16).max # fp16 max representable: ±2^16-32
         timesteps = torch.full(size=(len(z), ), fill_value=self.sigma[0], device=z.device)
-        _nan_expon = model_kwargs.pop('_nan_expon', None)
-        if _nan_expon is not None:
-            clip_val = 2**16 - 2**_nan_expon # fp16 overflows after ±2^16-32
-            model_llm.set_clip_val(clip_val)
+        _buff_expon = model_kwargs.pop('_buff_expon', None) # temp local recursion var
+
+        if _buff_expon is None:
+            # fp16 overflows at ±2^16-16 with largest repr being ±2^16-32. repr vals occur at intervals of 32 for nums > 2^15.
+            # Prelim tests show an additional buffer of at least 2 repr values is needed for stability; why is presently unclear.
+            # If this continues to hold true, this function can be deleted and replaced with 1 line in pipeline.
+            clip_val = fp16_max_repr - 2*32 # = 2**6 = (-2,+2 buffer vals)
+            if model_llm._clip_val is None or model_llm._clip_val > clip_val:
+                model_llm.set_clip_val(clip_val)
+                logger.debug(f'set initial clamp: (+-){clip_val} ...')
+        else:
+            clip_val = fp16_max_repr - 2**_buff_expon
+            model_llm.set_clip_val(clip_val) # clamp (-clip_val, +clip_val)
 
         try:
             _model_kwargs = copy.deepcopy(model_kwargs)
             _model_kwargs['use_kv_cache']=False # no cache while searching
             _, _ = func(z.clone(), timesteps, past_key_values=None, **_model_kwargs)
         except OverflowError:
-            if _nan_expon is None:
-                logger.info('FP16 overflow, searching for clamp bounds...')
-                _nan_expon = 5 # start at 2**5
+            if _buff_expon is None:
+                _buff_expon = 6 # start at 2**(6 + 1) (-4,+4 buffer vals)
+                logger.info('FP16 overflow, searching for clamp bounds...')
 
-            if _nan_expon < 15: # stop at 2**15
-                model_kwargs['_nan_expon'] = _nan_expon+1
+            if _buff_expon < 15: # stop at 2**15 (-1024,+1024 buffer vals)
+                _buff_expon += 1
+                # each iter, double the representable value buffer capacity for both min and max
+                model_kwargs['_buff_expon'] = _buff_expon
+                logger.debug(f'trying clamp: (+-){fp16_max_repr - 2**(_buff_expon)} ...')
                 return self._fp16_clip_autoset(model_llm, z, func, model_kwargs)
         raise OverflowError('Numerical overflow, unable to find suitable clipping bounds.')
```
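To make the search schedule concrete (an illustration, not code from the PR): the first dry run uses the tested default of fp16_max_repr - 2*32; each retry after an OverflowError doubles the buffer below the fp16 maximum, from 2**7 up to 2**15, after which the function re-raises.

```python
# Illustration only: the sequence of clamp bounds _fp16_clip_autoset will try.
fp16_max_repr = 2**16 - 32                       # 65504, torch.finfo(torch.float16).max
bounds = [fp16_max_repr - 2*32]                  # 65440: initial clamp, expected to suffice
bounds += [fp16_max_repr - 2**e for e in range(7, 16)]  # retries on OverflowError
print(bounds)
# [65440, 65376, 65248, 64992, 64480, 63456, 61408, 57312, 49120, 32736]
```

Per the commit message, the initial bound of 65440 sufficed for all tested inputs, so the expected number of extra dry-run iterations is 1.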