This module contains custom models, custom splitters, etc. for summarization tasks.
 
What we're running with at the time this documentation was generated:
torch: 1.9.0+cu102
fastai: 2.5.2
transformers: 4.10.0
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti

Summarization

The objective of summarization is to generate a concise and accurate representation of a much larger body of text. For example, we may want to summarize an article in a single sentence.

Prepare the data

path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
1000
cnndm_df.head(2)
article highlights ds_type
0 (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el... John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge" train
1 (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ... NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says . train
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
text_gen_kwargs = {}
if (hf_arch in ['bart', 't5']):
    text_gen_kwargs = {**hf_config.task_specific_params['summarization'], **{'max_length': 30, 'min_length': 10}}

# not all "summarization" parameters are for the model.generate method ... remove them here
generate_func_args = list(inspect.signature(hf_model.generate).parameters.keys())
for k in text_gen_kwargs.copy():
    if k not in generate_func_args: del text_gen_kwargs[k]

if (hf_arch == 'mbart'):
    text_gen_kwargs['decoder_start_token_id'] = hf_tokenizer.get_vocab()["en_XX"]
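
The filtering step above needs only the standard library's inspect module. Here is a minimal standalone sketch of the same idea, assuming a hypothetical `generate` stand-in (not the real `hf_model.generate`) and made-up parameter values. It shows later dictionary entries overriding earlier ones, and keys missing from the signature being dropped:

```python
import inspect

def generate(input_ids, max_length=20, min_length=0, num_beams=1):
    """Hypothetical stand-in for a model's generate method."""
    return {"max_length": max_length, "min_length": min_length, "num_beams": num_beams}

# Illustrative values only; "prefix" plays the role of a task parameter
# that the generate function does not accept.
task_params = {"max_length": 142, "min_length": 56, "num_beams": 4, "prefix": "summarize: "}

# Later entries win when dicts are merged with ** unpacking.
gen_kwargs = {**task_params, **{"max_length": 30, "min_length": 10}}

# Keep only the keys that appear in the function's signature.
accepted = set(inspect.signature(generate).parameters)
gen_kwargs = {k: v for k, v in gen_kwargs.items() if k in accepted}

print(gen_kwargs)
```

One caveat with this kind of check: if the target function forwards extra options through a catch-all `**kwargs` parameter, those pass-through keys are not listed by name in the signature and would be removed as well.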
tok_kwargs = {}
if (hf_arch == 'mbart'):
    tok_kwargs['src_lang'], tok_kwargs['tgt_lang'] = "en_XX", "en_XX"
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model, 
                                                  max_length=256, max_target_length=130,
                                                  tok_kwargs=tok_kwargs, text_gen_kwargs=text_gen_kwargs)

blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
dls = dblock.dataloaders(cnndm_df, bs=2)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 69]))
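
The shapes above are consistent with each article being truncated at max_length=256, while the targets are padded only to the longest target in the batch (69 tokens here, under the max_target_length=130 cap). Here is a rough pure-Python sketch of that truncate-then-pad behavior. The `pad_to_longest` helper, the token ids, and the pad id are all assumptions for illustration, not Blurr's actual implementation:

```python
from typing import List

def pad_to_longest(seqs: List[List[int]], max_len: int, pad_id: int = 1) -> List[List[int]]:
    """Truncate each sequence to max_len, then right-pad to the longest one left."""
    truncated = [s[:max_len] for s in seqs]
    width = max(len(s) for s in truncated)
    return [s + [pad_id] * (width - len(s)) for s in truncated]

# Made-up token ids; pad_id=1 is an assumption for this sketch.
targets = [[0, 11, 12, 13, 2], [0, 21, 2]]
padded = pad_to_longest(targets, max_len=4)
print(padded)
```

Because the width depends on the longest item in each batch, the second dimension of the target tensor can differ from one batch to the next.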
dls.show_batch(dataloaders=dls, max_n=2)
text target
0 <s> (CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds, 1,150 varieties of birds and 26,000 classifications of plants. Pronatura, a non-profit organization that works to promote conservation and sustainable development in Mexico, has selected six species which it says symbolize the problems faced by the</s> Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development and climate change is placing a big strain on its biodiversity.\nThe Golden Eagle is under threat in spite of being the country's national symbol.
1 <s> Some U.S. officials this year are expected to get smartphones capable of handling classified government documents over cellular networks, according to people involved in the project. The phones will run a modified version of Google's Android software, which is being developed as part of an initiative that spans multiple federal agencies and government contractors, these people said. The smartphones are first being deployed to U.S. soldiers, people familiar with the project said. Later, federal agencies are expected to get phones for sending and receiving government cables while away from their offices, sources said. Eventually, local governments and corporations could give workers phones with similar software. The Army has been testing touchscreen devices at U.S. bases for nearly two years, said Michael McCarthy, a director for the Army's Brigade Modernization Command, in a phone interview. About 40 phones were sent to fighters overseas a year ago, and the Army plans to ship 50 more phones and 75 tablets to soldiers abroad in March, he said. "We've had kind of an accelerated approval process," McCarthy said. "This is a hugely significant event." Currently, the United States doesn't allow government workers or soldiers to use smartphones for sending classified messages because the devices have not met security certifications. Officials have said they worry that hackers or</s> Government, military officials to get Android phones capable of sharing secret documents.\nThe phones will run a modified version of Google's Android software, sources say.\nContractor: Google "more cooperative" than Apple working with government on phones.

Train model

seq2seq_metrics = {
        'rouge': {
            'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
            'returns': ["rouge1", "rouge2", "rougeL"]
        },
        'bertscore': {
            'compute_kwargs': { 'lang': 'en' },
            'returns': ["precision", "recall", "f1"]
        }
    }
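
The metrics dictionary above requests ROUGE and BERTScore values. To give a feel for what rouge1 measures, here is a simplified sketch of a unigram-overlap F1 score. It lowercases and splits on whitespace only, with no stemming, so it is an illustration of the idea rather than the implementation behind the reported numbers. The `rouge1_f1` name is hypothetical:

```python
from collections import Counter

def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 on lowercased whitespace tokens (no stemming)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped overlap: each token counts at most as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1("the robbers fled into france", "the robbers escaped into france")
print(round(score, 3))
```

rouge2 applies the same overlap idea to bigrams, and rougeL is based on the longest common subsequence between prediction and reference.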
model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(dls, 
                model,
                opt_func=partial(Adam),
                loss_func=CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) #.to_native_fp16() #.to_fp16()

learn.freeze()
b = dls.one_batch()
preds = learn.model(b[0])

len(preds),preds['loss'].shape, preds['logits'].shape
(4, torch.Size([]), torch.Size([2, 79, 50264]))
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 256]), 2, torch.Size([2, 79]))
print(len(learn.opt.param_groups))
3
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=6.918309954926372e-05, steep=1.0964781722577754e-06, valley=3.630780702224001e-05, slide=1.4454397387453355e-05)
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch  train_loss  valid_loss  rouge1    rouge2    rougeL    bertscore_precision  bertscore_recall  bertscore_f1  time
0      1.799327    1.663339    0.317453  0.153232  0.252488  0.893058             0.864607          0.878485      02:11
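
The fit_one_cycle call above follows the one-cycle policy: the learning rate warms up to lr_max and then anneals to a much smaller value. Here is a minimal plain-Python sketch of such a schedule. It assumes cosine interpolation and the settings div=25, div_final=1e5 and pct_start=0.25; treat these as assumptions about fastai's defaults rather than a copy of its implementation:

```python
import math

def cos_anneal(start: float, end: float, pos: float) -> float:
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pct: float, lr_max: float, div: float = 25.0,
                 div_final: float = 1e5, pct_start: float = 0.25) -> float:
    """Learning rate at training progress pct in [0, 1]."""
    if pct < pct_start:
        # Warm-up phase: lr_max/div up to lr_max.
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    # Annealing phase: lr_max down to lr_max/div_final.
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))

for pct in (0.0, 0.25, 1.0):
    print(pct, one_cycle_lr(pct, lr_max=4e-5))
```

Under these assumptions, lr_max=4e-5 is the peak value reached a quarter of the way through training, not the rate used throughout.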
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 (CNN Student News) -- March 23, 2010. Download PDF maps related to today's show:. • Haiti • China. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. CARL AZUZ, CNN STUDENT NEWS ANCHOR: Happy birthday, Roger Bannister -- first man to run the mile in less than four minutes. In more than twice that time, you'll be up to speed on today's headlines. I'm Carl Azuz. First Up: Health Care. AZUZ: First up, it's the biggest expansion of the United States he Find out what comes next after the passage of a health care reform bill.\nLearn about a proposal that would change how student loans are funded.\nFollow the steps that led to a showdown between China and Google.\nUse the Daily Discussion to help studen Check out the reaction to the U.S. House of Representatives passing a health care reform bill .\nLearn about some of
1 Washington (CNN)Almost immediately following the news of the first terrorist attacks that eventually killed 17 people across France, the global community united around a Twitter hashtag "Je suis Charlie" and just days later foreign leaders linked arms with their French counterparts to lead a historic million-person strong rally. Meanwhile, explosives strapped to a girl who appeared to be about 10-years-old detonated on Saturday, killing at least 20 people, in a country whose encounters with ter France and Nigeria experienced waves of terrorism during the first weeks of 2015.\nWhile the terror attacks in Paris sparked international unified outrage, reaction to Nigeria was more muted.\nSymbolism, politics and media all played a role in how Fra Terrorist attacks in Paris and Nigeria fomented unprecedented international reaction .\nThe response to the attacks in Nigeria paled in
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French lRicense plates. CNN's Andreena Narayan 
contributed to this report.
"""
outputs = learn.blurr_generate(test_article, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off into France .
The robbers spoke French and drove vehicles with French lRicense plates .
There were no serious injuries, although one guest was kicked in the head by one of the robbers .

=== Prediction 2 ===
 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
About 10 men armed with pistols and small machine guns raided the Grand Casino Basel .
The robbers spoke French and drove vehicles with French lRicense plates .
There were no serious injuries, although one guest was kicked in the head by one of the robbers .

=== Prediction 3 ===
 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off into France .
The robbers spoke French and drove vehicles with French lRicense plates .
There were about 600 people in the casino at the time of the robbery .

Inference

export_fname = 'summarize_export'
learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_article)
[' Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .\nAbout 10']

High-level API

BlearnerForSummarization

We also introduce a task-specific Blearner that gets you your DataBlock, DataLoaders, and Blearner in one line of code!

class BlearnerForSummarization

BlearnerForSummarization(dls:DataLoaders, hf_model:PreTrainedModel, base_model_cb:HF_BaseModelCallback=HF_BaseModelCallback, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Blearner

Group together a model, some dls and a loss_func to handle training

Parameters:

  • dls : <class 'fastai.data.core.DataLoaders'>

  • hf_model : <class 'transformers.modeling_utils.PreTrainedModel'>

  • kwargs : <class 'inspect._empty'>

learn = BlearnerForSummarization.from_dataframe(cnndm_df, 'facebook/bart-large-cnn', 
                                                text_attr='article', summary_attr='highlights', 
                                                max_length=256, max_target_length=130,
                                                dblock_splitter=RandomSplitter(),
                                                dl_kwargs={'bs':2}).to_fp16()
learn.fit_one_cycle(1, lr_max=4e-5, cbs=[BlearnerForSummarization.get_metrics_cb()])
epoch  train_loss  valid_loss  rouge1    rouge2    rougeL    bertscore_precision  bertscore_recall  bertscore_f1  time
0      1.680803    1.656284    0.391191  0.175432  0.271317  0.876456             0.893137          0.884638      03:50
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 (CNN Student News) -- March 23, 2010. Download PDF maps related to today's show:. • Haiti • China. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. CARL AZUZ, CNN STUDENT NEWS ANCHOR: Happy birthday, Roger Bannister -- first man to run the mile in less than four minutes. In more than twice that time, you'll be up to speed on today's headlines. I'm Carl Azuz. First Up: Health Care. AZUZ: First up, it's the biggest expansion of the United States he Find out what comes next after the passage of a health care reform bill.\nLearn about a proposal that would change how student loans are funded.\nFollow the steps that led to a showdown between China and Google.\nUse the Daily Discussion to help studen Find out why the U.S. House of Representatives passed a health care reform bill late Sunday night .\nCheck out some of the reaction to last night's vote on the health care issue .\nLearn about the impact of the earthquake in Haiti .\nMeet a man who is
1 (CNN Student News) -- Parents and Teachers: Watch with your students or record "Gary + Tony Have a Baby" when it airs on CNN on Thursday, June 24 at 8 p.m. ET. By recording the documentary, you agree that you will use the program for educational viewing purposes for a one-year period only. No other rights of any kind or nature whatsoever are granted, including, without limitation, any rights to sell, publish, distribute, post online or distribute in any other medium or forum, or use for any com "Gary + Tony Have a Baby" is a documentary that follows the journey of a gay couple as they attempt to become parents.\nParents and educators can use this guide to initiate discussion with students about the documentary. "Gary + Tony Have a Baby" is a documentary that follows the journey of a gay couple as they attempt to become parents .\nWe recommend that you preview this program and determine whether it is appropriate before showing it to students .\nBy recording t
outputs = learn.blurr_generate(test_article, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 Police: About 10 men armed with pistols and machine guns raided a casino in Switzerland .
They made off with several hundred thousand Swiss francs in the early hours of Sunday morning .
The robbers spoke French and drove vehicles with French lRicense plates .
There were no serious injuries, but one guest was kicked in the head by one of the robbers .

=== Prediction 2 ===
 Police: About 10 men armed with pistols and machine guns raided a casino in Switzerland .
They made off with several hundred thousand Swiss francs in the early hours of Sunday morning .
The robbers spoke French and drove vehicles with French lRicense plates .
There were no serious injuries, although one guest was kicked in the head by one of the robbers .

=== Prediction 3 ===
 Police: About 10 men armed with pistols and machine guns raided a casino in Switzerland .
They made off with several hundred thousand Swiss francs in the early hours of Sunday morning .
The robbers spoke French and drove vehicles with French lRicense plates .
There were about 600 people in the casino at the time of the robbery .

export_fname = 'summarize_export'

learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')

inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_article)
[' Police: About 10 men armed with pistols and machine guns raided a casino in Switzerland .\nThey made off with several hundred thousand Swiss francs in the early hours of Sunday morning .\nThe robbers spoke French and drove vehicles with French lRicense plates .\nThere were no serious injuries, but one guest was kicked in the head by one of the robbers .']

Tests

The purpose of the following tests is to ensure, as far as possible, that the core training code works for the pretrained summarization models below. These tests are excluded from the CI workflow because of how long they take to run and how much data they need to download.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with. If any of them fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

   arch     tokenizer             model_name                       result  error
0  bart     BartTokenizerFast     BartForConditionalGeneration     PASSED
1  led      LEDTokenizerFast      LEDForConditionalGeneration      PASSED
2  mbart    MBartTokenizerFast    MBartForConditionalGeneration    PASSED
3  mt5      T5TokenizerFast       MT5ForConditionalGeneration      PASSED
4  pegasus  PegasusTokenizerFast  PegasusForConditionalGeneration  PASSED
5  t5       T5TokenizerFast       T5ForConditionalGeneration       PASSED
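
A table like the one above can be produced by a small loop that trains each model briefly and records whether it raised an error. The sketch below shows only that pattern. `run_model_tests`, `fake_train`, and the "broken-model" name are hypothetical stand-ins, not the test code used to generate this page:

```python
from typing import Callable, Dict, List

def run_model_tests(model_names: List[str], train_fn: Callable[[str], None]) -> List[Dict[str, str]]:
    """Run train_fn for each model name and record PASSED or FAILED with the error text."""
    results = []
    for name in model_names:
        try:
            train_fn(name)
            results.append({"model_name": name, "result": "PASSED", "error": ""})
        except Exception as err:  # record the failure and keep going
            results.append({"model_name": name, "result": "FAILED", "error": str(err)})
    return results

def fake_train(name: str) -> None:
    """Hypothetical stand-in for a short training run."""
    if name == "broken-model":
        raise ValueError("unsupported architecture")

report = run_model_tests(["facebook/bart-large-cnn", "broken-model"], fake_train)
for row in report:
    print(row)
```

Catching the exception per model means one failing architecture does not stop the remaining models from being tested, and the error text can be pasted into an issue.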

Summary

This module includes the fundamental bits needed to use Blurr for training and inference on summarization tasks.