This module contains custom models, custom splitters, etc., for summarization tasks.
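The cells below assume the notebook's earlier setup/import cells have already been run. As a rough sketch only (the exact module paths are an assumption based on the older blurr API these names come from), the required imports would look something like:

import inspect
from functools import partial

import pandas as pd
import torch

from fastai.text.all import *        # Learner, DataBlock, ColReader, Adam, ranger, ShortEpochCallback, ...
from fastcore.test import test_eq    # used by the tests at the end of this module
from transformers import BartForConditionalGeneration

from blurr.data.all import *         # assumed to provide BLURR_MODEL_HELPER, HF_Seq2SeqBlock, HF_Seq2SeqBeforeBatchTransform, ...
from blurr.modeling.all import *     # assumed to provide HF_BaseModelWrapper, HF_Seq2SeqMetricsCallback, HF_PreCalculatedLoss, ...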
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
cnndm_df.head(2)
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name,
model_cls=BartForConditionalGeneration)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
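# build the DataBlock (starting with architecture-specific text generation and tokenizer kwargs)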
text_gen_kwargs = {}
if (hf_arch in ['bart', 't5']):
    text_gen_kwargs = {**hf_config.task_specific_params['summarization'], **{'max_length': 30, 'min_length': 10}}

# not all "summarization" parameters are for the model.generate method ... remove them here
generate_func_args = list(inspect.signature(hf_model.generate).parameters.keys())
for k in text_gen_kwargs.copy():
    if k not in generate_func_args: del text_gen_kwargs[k]

if (hf_arch == 'mbart'):
    text_gen_kwargs['decoder_start_token_id'] = hf_tokenizer.get_vocab()["en_XX"]
tok_kwargs = {}
if (hf_arch == 'mbart'):
    tok_kwargs['src_lang'], tok_kwargs['tgt_lang'] = "en_XX", "en_XX"
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                  max_length=256, max_target_length=130,
                                                  tok_kwargs=tok_kwargs, text_gen_kwargs=text_gen_kwargs)
blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('article'),
                   get_y=ColReader('highlights'),
                   splitter=RandomSplitter())
dls = dblock.dataloaders(cnndm_df, bs=2)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
dls.show_batch(dataloaders=dls, max_n=2)
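# build the Learner: metrics, model wrapper, and callbacks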
seq2seq_metrics = {
    'rouge': {
        'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
        'returns': ["rouge1", "rouge2", "rougeL"]
    },
    'bertscore': {
        'compute_kwargs': { 'lang': 'en' },
        'returns': ["precision", "recall", "f1"]
    }
}
model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]
learn = Learner(dls,
                model,
                opt_func=partial(Adam),
                loss_func=CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) #.to_native_fp16() #.to_fp16()
learn.create_opt()
learn.freeze()
b = dls.one_batch()
preds = learn.model(b[0])
len(preds),preds['loss'].shape, preds['logits'].shape
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
print(len(learn.opt.param_groups))
learn.lr_find(suggestions=True)
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
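# generate summaries for a raw article with the fine-tuned model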
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said.
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat
her, and took off for the French border. The other gunmen followed into France, which is only about 100
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery.
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities,
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan
contributed to this report.
"""
outputs = learn.blurr_generate(test_article, num_return_sequences=3)
for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
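# export the Learner and run inference with the reloaded model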
export_fname = 'summarize_export'
learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_article)
Tests
The purpose of the following tests is to ensure, as much as possible, that the core training code works for the pretrained summarization models below. These tests are excluded from the CI workflow because of how long they take to run and the amount of data they would need to download.
Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with ... and if any of them fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).
try: del learn; torch.cuda.empty_cache()
except: pass
[ model_type for model_type in BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')
if (not model_type.__name__.startswith('TF')) ]
pretrained_model_names = [
'facebook/bart-base',
#'facebook/blenderbot_small-90M',
'allenai/led-base-16384',
'sshleifer/tiny-mbart',
'google/mt5-small',
'sshleifer/distill-pegasus-cnn-16-4',
't5-small',
#'microsoft/prophetnet-large-uncased',
#'microsoft/xprophetnet-large-wiki100-cased', # XLMProphetNet
]
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
#hide_output
task = HF_TASKS_AUTO.Seq2SeqLM
bsz = 2
inp_seq_sz = 64; trg_seq_sz = 40
test_results = []
for model_name in pretrained_model_names:
    error = None

    print(f'=== {model_name} ===\n')

    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, task=task)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n')

    # 1. build your DataBlock
    text_gen_kwargs = {}
    if (hf_arch in ['bart', 't5']):
        text_gen_kwargs = {**hf_config.task_specific_params['summarization'], **{'max_length': 30, 'min_length': 10}}

    # not all "summarization" parameters are for the model.generate method ... remove them here
    generate_func_args = list(inspect.signature(hf_model.generate).parameters.keys())
    for k in text_gen_kwargs.copy():
        if k not in generate_func_args: del text_gen_kwargs[k]

    if (hf_arch == 'mbart'):
        text_gen_kwargs['decoder_start_token_id'] = hf_tokenizer.get_vocab()["en_XX"]
    tok_kwargs = {}
    if (hf_arch == 'mbart'):
        tok_kwargs['src_lang'], tok_kwargs['tgt_lang'] = "en_XX", "en_XX"

    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp

    before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                      padding='max_length',
                                                      max_length=inp_seq_sz,
                                                      max_target_length=trg_seq_sz,
                                                      tok_kwargs=tok_kwargs, text_gen_kwargs=text_gen_kwargs)

    blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
    dblock = DataBlock(blocks=blocks,
                       get_x=Pipeline([ColReader('article'), add_t5_prefix]),
                       get_y=ColReader('highlights'),
                       splitter=RandomSplitter())

    dls = dblock.dataloaders(cnndm_df, bs=bsz)
    b = dls.one_batch()
    # 2. build your Learner
    seq2seq_metrics = {
        'rouge': {
            'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
            'returns': ["rouge1", "rouge2", "rougeL"]
        }
    }

    model = HF_BaseModelWrapper(hf_model)
    learn_cbs = [HF_BaseModelCallback]
    fit_cbs = [
        ShortEpochCallback(0.05, short_valid=True),
        HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)
    ]

    learn = Learner(dls,
                    model,
                    opt_func=ranger,
                    loss_func=HF_PreCalculatedLoss(),
                    cbs=learn_cbs,
                    splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

    learn.create_opt()
    learn.freeze()
    # 3. Run your tests
    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

        # print('*** TESTING One pass through the model ***')
        # preds = learn.model(b[0])
        # test_eq(preds[1].shape[0], bsz)
        # test_eq(preds[1].shape[2], hf_config.vocab_size)

        print('*** TESTING Training/Results ***')
        learn.fit_one_cycle(1, lr_max=1e-3, cbs=fit_cbs)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'PASSED', ''))
        learn.show_results(learner=learn, max_n=2, input_trunc_at=500, target_trunc_at=250)

    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'FAILED', err))

    finally:
        # cleanup
        del learn; torch.cuda.empty_cache()