callbacks
Callbacks used by the BLURR library.
Gradient Checkpointing
CheckpointingNotSupported
CheckpointingNotSupported (msg='Model does not support gradient checkpointing.')
Raised when the underlying HuggingFace model does not support gradient checkpointing.
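As an illustration (a sketch of the idea, not blurr's actual source), an exception like this would typically be raised when a model lacks checkpointing support; HuggingFace `PreTrainedModel` subclasses advertise support via the `supports_gradient_checkpointing` attribute:

```python
class CheckpointingNotSupported(Exception):
    """Raised when a model cannot use gradient checkpointing."""

    def __init__(self, msg="Model does not support gradient checkpointing."):
        super().__init__(msg)


def enable_checkpointing(model):
    # HuggingFace models expose `supports_gradient_checkpointing`; anything
    # without it (or with it set to False) cannot checkpoint.
    if not getattr(model, "supports_gradient_checkpointing", False):
        raise CheckpointingNotSupported()
    model.gradient_checkpointing_enable()


class NoCheckpointModel:  # hypothetical stand-in for an unsupported model
    pass


try:
    enable_checkpointing(NoCheckpointModel())
except CheckpointingNotSupported as e:
    print(e)  # Model does not support gradient checkpointing.
```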
GradientCheckpointing
GradientCheckpointing (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
A fastai callback to enable gradient checkpointing for compatible HuggingFace models.
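To make the callback's behavior concrete, here is a minimal sketch (an assumption about the mechanism, not blurr's actual implementation) of how a fastai-style callback can switch HuggingFace gradient checkpointing on before training and off afterwards. Checkpointing trades compute for memory: intermediate activations are discarded in the forward pass and recomputed during backward. The dummy model and learner below are hypothetical stand-ins so the sketch is self-contained:

```python
class GradientCheckpointingSketch:
    """Sketch of a callback that enables checkpointing only while fitting."""

    def before_fit(self, learn):
        model = learn.model
        if not getattr(model, "supports_gradient_checkpointing", False):
            raise RuntimeError("Model does not support gradient checkpointing.")
        # Real HF models implement gradient_checkpointing_enable().
        model.gradient_checkpointing_enable()

    def after_fit(self, learn):
        # Restore the default (no checkpointing) once training ends.
        learn.model.gradient_checkpointing_disable()


class DummyModel:  # hypothetical model mimicking the HF interface
    supports_gradient_checkpointing = True

    def __init__(self):
        self.checkpointing = False

    def gradient_checkpointing_enable(self):
        self.checkpointing = True

    def gradient_checkpointing_disable(self):
        self.checkpointing = False


class DummyLearner:  # hypothetical minimal learner holding a model
    def __init__(self, model):
        self.model = model


learn = DummyLearner(DummyModel())
cb = GradientCheckpointingSketch()
cb.before_fit(learn)   # checkpointing is now on for the training run
cb.after_fit(learn)    # and switched back off afterwards
```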
We’ll use a small sample of the IMDB dataset for testing
```python
path = untar_data(URLs.IMDB_SAMPLE)
model_path = Path("models")
imdb_df = pd.read_csv(path / "texts.csv")
```

Let’s look at memory consumption without `GradientCheckpointing`
```python
nvidia_smi_idx = 2

def gpu_memory(device_idx=nvidia_smi_idx):
    return GPU.getGPUs()[device_idx].memoryUsed
```

```python
learn = BlearnerForSequenceClassification.from_data(imdb_df, "roberta-large", dl_kwargs={"bs": 4})
learn.fit_one_cycle(1, lr_max=1e-3)

base_mem = gpu_memory()
print(f"{base_mem} MBs used.")
reset_memory(learn)
```

| epoch | train_loss | valid_loss | f1_score | accuracy | time |
|---|---|---|---|---|---|
| 0 | 0.341047 | 0.237419 | 0.918033 | 0.925000 | 00:57 |
9499.0 MBs used.
Let’s look at memory consumption with `GradientCheckpointing`

```python
learn = BlearnerForSequenceClassification.from_data(imdb_df, "roberta-large", dl_kwargs={"bs": 4})
learn.fit_one_cycle(1, lr_max=1e-3, cbs=[GradientCheckpointing()])

check_mem = gpu_memory()
print(f"{check_mem} MBs used.")
test_eq(base_mem > check_mem, True)
reset_memory(learn)
```

| epoch | train_loss | valid_loss | f1_score | accuracy | time |
|---|---|---|---|---|---|
| 0 | 0.299704 | 0.222900 | 0.920455 | 0.930000 | 01:22 |
4297.0 MBs used.
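Comparing the two runs, gradient checkpointing roughly halves peak GPU memory in exchange for a longer epoch. The trade-off can be quantified directly from the numbers reported above:

```python
# Figures copied from the two runs above.
base_mem, check_mem = 9499.0, 4297.0   # MBs used per run
base_time, check_time = 57, 82         # seconds per epoch (00:57 vs 01:22)

mem_savings = (base_mem - check_mem) / base_mem
time_overhead = check_time / base_time - 1

print(f"memory saved:  {mem_savings:.1%}")   # ~54.8%
print(f"time overhead: {time_overhead:.1%}") # ~43.9%
```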