Commit fe9b2fe
committed
Use standard gradient checkpointing for small sequence lengths
When max_seq_length < 512, the overhead of gradient offloading in
gc="unsloth" mode is not worth it. Benchmarks on B200 show:
| seq_len | gc=unsloth | gc=True | Difference |
|---------|------------|----------|------------|
| 256 | 6,803 t/s | 6,993 t/s| +2.8% |
| 384 | 9,889 t/s | 9,963 t/s| +0.7% |
| 512 | 13,151 t/s | 13,092 t/s| -0.4% |
| 1024 | 26,662 t/s | 25,094 t/s| -5.9% |
The crossover point is around seq_len 384-512. For sequences shorter
than 512, we now automatically use standard gradient checkpointing
instead of the custom offloading implementation.1 parent 010775f commit fe9b2fe
1 file changed
+9
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2641 | 2641 | | |
2642 | 2642 | | |
2643 | 2643 | | |
2644 | | - | |
2645 | | - | |
2646 | | - | |
| 2644 | + | |
| 2645 | + | |
| 2646 | + | |
| 2647 | + | |
| 2648 | + | |
| 2649 | + | |
| 2650 | + | |
| 2651 | + | |
| 2652 | + | |
2647 | 2653 | | |
2648 | 2654 | | |
2649 | 2655 | | |
| |||
0 commit comments