callbacks
Callbacks used by the BLURR library.
Gradient Checkpointing
CheckpointingNotSupported
CheckpointingNotSupported (msg='Model does not support gradient checkpointing.')
Raised when a model does not support gradient checkpointing.
GradientCheckpointing
GradientCheckpointing (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
A fastai callback to enable gradient checkpointing for compatible HuggingFace models.
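HuggingFace models expose `gradient_checkpointing_enable()` and `gradient_checkpointing_disable()` to toggle checkpointing, and compatible models advertise support via `supports_gradient_checkpointing`. The sketch below illustrates how a callback of this kind *could* wrap that API; it is not BLURR's actual implementation, and `DummyModel` is an illustrative stand-in for a real transformer:

```python
class DummyModel:
    """Illustrative stand-in for a HuggingFace PreTrainedModel."""

    def __init__(self, supports_checkpointing=True):
        self.supports_gradient_checkpointing = supports_checkpointing
        self.is_gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        self.is_gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.is_gradient_checkpointing = False


class CheckpointingNotSupported(Exception):
    def __init__(self, msg="Model does not support gradient checkpointing."):
        super().__init__(msg)


class GradientCheckpointingSketch:
    """Hedged sketch: enable checkpointing before training, disable after."""

    def __init__(self, model):
        self.model = model

    def before_fit(self):
        # Refuse to run on models that don't support checkpointing
        if not getattr(self.model, "supports_gradient_checkpointing", False):
            raise CheckpointingNotSupported()
        self.model.gradient_checkpointing_enable()

    def after_fit(self):
        # Restore the model's default behavior once training finishes
        self.model.gradient_checkpointing_disable()
```

In the real callback the model lives on the fastai `Learner` and the hooks are invoked by fastai's callback system; the core idea is the same enable/disable pairing shown here.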
We’ll use a small sample of the IMDB dataset for testing.
path = untar_data(URLs.IMDB_SAMPLE)
model_path = Path("models")
imdb_df = pd.read_csv(path / "texts.csv")
Let’s look at memory consumption without GradientCheckpointing
nvidia_smi_idx = 2

def gpu_memory(device_idx=nvidia_smi_idx):
    # `GPU` comes from the GPUtil package; returns used memory in MBs
    return GPU.getGPUs()[device_idx].memoryUsed
learn = BlearnerForSequenceClassification.from_data(imdb_df, "roberta-large", dl_kwargs={"bs": 4})
learn.fit_one_cycle(1, lr_max=1e-3)

base_mem = gpu_memory()
print(f"{base_mem} MBs used.")
reset_memory(learn)
epoch | train_loss | valid_loss | f1_score | accuracy | time |
---|---|---|---|---|---|
0 | 0.341047 | 0.237419 | 0.918033 | 0.925000 | 00:57 |
9499.0 MBs used.
Let’s look at memory consumption with GradientCheckpointing
learn = BlearnerForSequenceClassification.from_data(imdb_df, "roberta-large", dl_kwargs={"bs": 4})
learn.fit_one_cycle(1, lr_max=1e-3, cbs=[GradientCheckpointing()])

check_mem = gpu_memory()
print(f"{check_mem} MBs used.")
test_eq(base_mem > check_mem, True)
reset_memory(learn)
epoch | train_loss | valid_loss | f1_score | accuracy | time |
---|---|---|---|---|---|
0 | 0.299704 | 0.222900 | 0.920455 | 0.930000 | 01:22 |
4297.0 MBs used.
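The two runs above make the trade-off concrete: checkpointing cuts peak GPU memory roughly in half at the cost of a slower epoch. Using the numbers reported above:

```python
# Figures reported in the two runs above
base_mem, check_mem = 9499.0, 4297.0  # MBs used without / with checkpointing
base_time, check_time = 57, 82        # seconds per epoch, from the tables

mem_saved = 1 - check_mem / base_mem   # fraction of GPU memory saved (~0.55)
time_cost = check_time / base_time - 1 # fractional slowdown per epoch (~0.44)
print(f"memory saved: {mem_saved:.0%}, extra time per epoch: {time_cost:.0%}")
```

In other words, for this model and batch size, gradient checkpointing trades about 44% more wall-clock time per epoch for roughly 55% less GPU memory, which often makes larger batch sizes or larger models fit on the same card.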