torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
cnndm_df.head(2)
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    pretrained_model_name, model_cls=BartForConditionalGeneration)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                  max_length=256, max_target_length=130)
blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('article'),
                   get_y=ColReader('highlights'),
                   splitter=RandomSplitter())
dls = dblock.dataloaders(cnndm_df, bs=2)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
dls.show_batch(dataloaders=dls, max_n=2)
Training
Here we create a Seq2Seq-specific subclass of HF_BaseModelCallback in order to include custom Seq2Seq metrics, and also to handle the pre-calculated loss during training.
The seq2seq_metrics argument accepts entries of the following form:
- { 'rouge': { 'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True }, 'returns': ["rouge1", "rouge2", "rougeL"] } }
- { 'bertscore': { 'returns': ["precision", "recall", "f1"] } }
- { 'bleu': { 'returns': "bleu" } }
- { 'bleurt': { 'returns': "scores" } }
- { 'meteor': { 'returns': "meteor" } }
- { 'sacrebleu': { 'returns': "score" } }
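About the pre-calculated loss: huggingface models compute and return their own loss whenever 'labels' are included in the inputs, which blurr's before-batch transform takes care of (that's why preds['loss'] shows up below even though we only pass the inputs). A pass-through loss function can then simply hand that value back to fastai. Here's a minimal sketch of the idea; ToyPreCalculatedLoss is an illustrative stand-in, not blurr's actual HF_PreCalculatedLoss:

class ToyPreCalculatedLoss:
    "Ignore the targets and return the loss the huggingface model already computed."
    def __call__(self, model_output, *targets):
        # the output contains 'loss' because 'labels' were passed in with the inputs
        return model_output['loss']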
We also add a custom param splitter so we can apply discriminative learning rates with a bit more granularity for Seq2Seq models (a rough sketch of the idea follows the metrics config below).
seq2seq_metrics = {
    'rouge': {
        'compute_kwargs': {
            'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True
        },
        'returns': ["rouge1", "rouge2", "rougeL"]
    },
    'bertscore': {
        'compute_kwargs': { 'lang': 'en' },
        'returns': ["precision", "recall", "f1"]
    },
    'bleu': { 'returns': "bleu" },
    'meteor': { 'returns': "meteor" },
    'sacrebleu': { 'returns': "score" }
}
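As for the param splitter: a fastai splitter just returns lists of parameter groups, which the optimizer then maps to per-group learning rates. The sketch below illustrates the pattern for BART-like models; toy_seq2seq_splitter and its two-group split are assumptions, not blurr's actual seq2seq_splitter logic:

def toy_seq2seq_splitter(model, arch='bart'):
    "Split a wrapped BART model into [encoder, decoder] parameter groups (illustrative only)."
    base = model.hf_model.model  # BartForConditionalGeneration keeps encoder/decoder under .model
    enc_params = list(base.encoder.parameters())
    seen = {id(p) for p in enc_params}
    # BART ties embedding weights across encoder/decoder, so dedupe the shared params
    dec_params = [p for p in base.decoder.parameters() if id(p) not in seen]
    return [enc_params, dec_params]

With a splitter like this, learn.freeze() leaves only the last group trainable, and lr_max can be a slice so each group gets its own learning rate.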
model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]
learn = Learner(dls,
                model,
                opt_func=partial(Adam),
                loss_func=CrossEntropyLossFlat(), # HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) # .to_native_fp16() # .to_fp16()
learn.create_opt()
learn.freeze()
# sanity check: run one batch through the wrapped model
b = dls.one_batch()
preds = learn.model(b[0])
len(preds), preds['loss'].shape, preds['logits'].shape
print(len(learn.opt.param_groups))
learn.lr_find(suggestions=True)
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said.
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat
her, and took off for the French border. The other gunmen followed into France, which is only about 100
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery.
There were no serious injuries, although one guest on the casino floor was kicked in the head by one of the
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities,
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan
contributed to this report.
"""
res = learn.blurr_predict(test_article)
print(hf_tokenizer.decode(res[0][0][0][:20]))
That doesn't look much like human-generated text. Let's use huggingface's PreTrainedModel.generate method to create something more human-like.
# grab a validation batch, plus the blurr transform so we can get at the tokenizer
b = dls.valid.one_batch()
b_before_batch_tfm = get_blurr_tfm(dls.before_batch)
b_hf_tokenizer = b_before_batch_tfm.hf_tokenizer
b_ignore_token_id = b_before_batch_tfm.ignore_token_id

# move the first example to the model's device and strip ignored (padding) ids from the targets
test_input_ids = b[0]['input_ids'][0].unsqueeze(0).to(learn.model.hf_model.device)
test_trg_ids = b[1][0].unsqueeze(0).to(learn.model.hf_model.device)
test_trg_ids = [trg[trg != b_ignore_token_id] for trg in test_trg_ids]

# generate a summary with beam search
gen_text = learn.model.hf_model.generate(test_input_ids, num_beams=4, max_length=130, min_length=30)
print('=== Target ===')
print(f'{b_hf_tokenizer.decode(test_trg_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)}\n')
print('=== Prediction ===')
print(b_hf_tokenizer.decode(gen_text[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
We'll add a blurr_generate method to Learner that uses huggingface's PreTrainedModel.generate to create our predictions. For the full list of arguments you can pass in, see the huggingface documentation for generate; you can also check out their "How To Generate" notebook for more information about how it all works.
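For illustration, here's roughly what patching such a method onto Learner looks like with fastcore's @patch. This is a simplified sketch, not blurr's actual implementation; simple_generate, the generation settings, and the reliance on the hf_tokenizer from earlier in the notebook are all assumptions:

from fastcore.basics import patch

@patch
def simple_generate(self: Learner, text, **generate_kwargs):
    # tokenize the raw text with the same tokenizer the DataLoaders were built with
    inputs = hf_tokenizer(text, return_tensors='pt', truncation=True, max_length=256)
    input_ids = inputs['input_ids'].to(self.model.hf_model.device)
    # delegate to huggingface's PreTrainedModel.generate (beam search here)
    gen_ids = self.model.hf_model.generate(input_ids, num_beams=4, max_length=130, **generate_kwargs)
    return [hf_tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for ids in gen_ids]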
outputs = learn.blurr_generate(test_article, num_return_sequences=3)
for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
Much nicer!!! Now we can update our @typedispatched show_results to use this new method.
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
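For context, fastai dispatches show_results on the types of the inputs and targets, which is what lets blurr register a Seq2Seq-specific version. A bare-bones sketch of that pattern follows; MySeq2SeqInput, the column names, and the truncation handling are placeholders rather than blurr's actual types:

import pandas as pd
from fastcore.dispatch import typedispatch
from IPython.display import display

class MySeq2SeqInput(tuple): pass  # hypothetical dispatch key; blurr defines its own input type

@typedispatch
def show_results(x: MySeq2SeqInput, y, samples, outs, ctxs=None, max_n=2, trunc_at=250, **kwargs):
    # render truncated (input, target, prediction) triples as a dataframe
    rows = [(str(s[0])[:trunc_at], str(s[1])[:trunc_at], str(o[0])[:trunc_at])
            for s, o in zip(samples[:max_n], outs[:max_n])]
    display(pd.DataFrame(rows, columns=['text', 'target', 'prediction']))
    return ctxs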
export_fname = 'summarize_export'
learn.metrics = None  # clear the metrics before exporting; custom metric objects may not pickle
learn.export(fname=f'{export_fname}.pkl')
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_article)