Commit fe9b2fe

Use standard gradient checkpointing for small sequence lengths
When max_seq_length < 512, the overhead of gradient offloading in gc="unsloth" mode is not worth it. Benchmarks on a B200 show:

| seq_len | gc="unsloth" | gc=True    | gc=True vs "unsloth" |
|---------|--------------|------------|----------------------|
| 256     | 6,803 t/s    | 6,993 t/s  | +2.8%                |
| 384     | 9,889 t/s    | 9,963 t/s  | +0.7%                |
| 512     | 13,151 t/s   | 13,092 t/s | -0.4%                |
| 1024    | 26,662 t/s   | 25,094 t/s | -5.9%                |

The crossover point is around seq_len 384-512. For sequences shorter than 512, we now automatically use standard gradient checkpointing instead of the custom offloading implementation.
1 parent 010775f commit fe9b2fe
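
For context, a minimal usage sketch of how the change plays out from the caller's side, assuming the standard unsloth API (`FastLanguageModel.from_pretrained` / `FastLanguageModel.get_peft_model`); the model name and hyperparameters below are illustrative, not taken from this commit:

```python
from unsloth import FastLanguageModel

# Short-context run: max_seq_length < 512, so even though we request
# use_gradient_checkpointing="unsloth", get_peft_model now falls back
# to standard gradient checkpointing (use_gradient_checkpointing=True).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",  # illustrative model
    max_seq_length = 256,
    load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
    use_gradient_checkpointing = "unsloth",  # overridden to True for seq < 512
    random_state = 3407,
)
```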

File tree: 1 file changed (+9 −3)

unsloth/models/llama.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -2641,9 +2641,15 @@ def get_peft_model(
     transformers_set_seed(random_state)
 
     if use_gradient_checkpointing == "unsloth":
-        patch_unsloth_smart_gradient_checkpointing(
-            dtype = model.get_input_embeddings().weight.dtype
-        )
+        # Gradient offloading overhead is not worth it for small sequences.
+        # Benchmarks show crossover point is around seq_len 384-512.
+        # For seq < 512, standard gradient checkpointing is faster.
+        if hasattr(model, "max_seq_length") and model.max_seq_length < 512:
+            use_gradient_checkpointing = True
+        else:
+            patch_unsloth_smart_gradient_checkpointing(
+                dtype = model.get_input_embeddings().weight.dtype
+            )
 
     if type(r) is not int:
         raise TypeError(f"Unsloth: Rank of {str(r)} must be an integer.")
```
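
To see where the per-step cost that motivates this change comes from, here is a minimal, hypothetical sketch of CPU-offloaded activation checkpointing in plain PyTorch. This is not unsloth's actual implementation of `patch_unsloth_smart_gradient_checkpointing`; the class and its structure are assumptions for illustration:

```python
import torch

class OffloadedCheckpoint(torch.autograd.Function):
    """Gradient checkpointing that parks the saved activation on the CPU."""

    @staticmethod
    def forward(ctx, forward_fn, hidden_states):
        # Offload the input activation to host memory instead of keeping it
        # resident on the GPU for the whole forward pass.
        ctx.forward_fn = forward_fn
        ctx.device = hidden_states.device
        ctx.save_for_backward(hidden_states.to("cpu", non_blocking=True))
        with torch.no_grad():
            return forward_fn(hidden_states)

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        # Copy the activation back to the GPU and recompute the forward pass.
        # The device<->host copies are a roughly fixed cost per layer; when
        # seq_len (hence activation size) is small, the copies cost more time
        # than the GPU memory they free is worth.
        hidden_states = saved.to(ctx.device).detach().requires_grad_(True)
        with torch.enable_grad():
            output = ctx.forward_fn(hidden_states)
        torch.autograd.backward(output, grad_output)
        return None, hidden_states.grad
```

Standard gradient checkpointing (`torch.utils.checkpoint.checkpoint`) keeps the saved activation on the GPU and only pays the recompute, skipping both copies, which is consistent with it winning below the ~384-512 crossover in the benchmarks above.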
