1 parent 86ac601 · commit 986c500
qlora.py
@@ -267,6 +267,7 @@ def get_accelerate_model(args, checkpoint_dir):
     compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
+        cache_dir=args.cache_dir,
         load_in_4bit=args.bits == 4,
         load_in_8bit=args.bits == 8,
         device_map='auto',
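
For context, this change passes the user-supplied cache directory through to the Hugging Face model download, so weights are stored under args.cache_dir rather than the default cache location. A minimal standalone sketch of the effect (not part of the commit; the model name and directory below are illustrative placeholders):

from transformers import AutoModelForCausalLM

# cache_dir controls where from_pretrained downloads and caches the weights.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",       # hypothetical --model_name_or_path value
    cache_dir="/data/hf_cache",  # hypothetical --cache_dir value
    device_map="auto",
)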