This module contains custom models, custom splitters, etc., for summarization tasks.
 
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti

Summarization

The objective of summarization is to generate a concise and accurate representation of a much larger body of text. For example, we may want to summarize an article in a single sentence.

Prepare the data

path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
1000
cnndm_df.head(2)
article highlights ds_type
0 (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el... John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge" train
1 (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ... NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says . train
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
text_gen_kwargs = {}
if (hf_arch in ['bart', 't5']):
    text_gen_kwargs = {**hf_config.task_specific_params['summarization'], **{'max_length': 30, 'min_length': 10}}

# not all "summarization" parameters are for the model.generate method ... remove them here
generate_func_args = list(inspect.signature(hf_model.generate).parameters.keys())
for k in text_gen_kwargs.copy():
    if k not in generate_func_args: del text_gen_kwargs[k]

if (hf_arch == 'mbart'):
    text_gen_kwargs['decoder_start_token_id'] = hf_tokenizer.get_vocab()["en_XX"]
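The summarization defaults stored in a model config can include keys that `generate` does not take as named parameters (T5's defaults, for example, include a text `prefix`). That is why the code above filters them by inspecting the method's signature. A self-contained sketch of the same filtering, with a hypothetical `fake_generate` standing in for `hf_model.generate` and illustrative config values (not read from any real config):

```python
import inspect


def fake_generate(input_ids=None, max_length=None, min_length=None,
                  num_beams=None, length_penalty=None, **model_kwargs):
    """Hypothetical stand-in for a model's generate() method."""
    return input_ids


def filter_generate_kwargs(func, kwargs):
    """Keep only the entries of `kwargs` that are named parameters of `func`."""
    params = inspect.signature(func).parameters
    return {k: v for k, v in kwargs.items() if k in params}


# Illustrative "task_specific_params"-style dict with one key generate() does not name
summarization_params = {"num_beams": 4, "length_penalty": 2.0, "prefix": "summarize: "}
text_gen_kwargs = {**summarization_params, "max_length": 30, "min_length": 10}
print(filter_generate_kwargs(fake_generate, text_gen_kwargs))
```

Like the loop above, this keeps only explicitly named parameters, so a key that `generate` would otherwise absorb through `**model_kwargs` is dropped as well.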
tok_kwargs = {}
if (hf_arch == 'mbart'):
    tok_kwargs['src_lang'], tok_kwargs['tgt_lang'] = "en_XX", "en_XX"
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model, 
                                                  max_length=256, max_target_length=130,
                                                  tok_kwargs=tok_kwargs, text_gen_kwargs=text_gen_kwargs)

blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
dls = dblock.dataloaders(cnndm_df, bs=2)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 68]))
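The shapes above are consistent with inputs being truncated to `max_length=256` and targets being padded only up to the longest summary in the batch (a different batch later in this notebook has a target width of 78). A minimal sketch of that truncate-then-pad-to-longest behaviour, assuming this is what the transform does by default (the tests further down pass `padding='max_length'` explicitly to get fixed widths instead). `pad_batch` and `pad_id=1` are illustrative, not blurr's API:

```python
from typing import List, Optional, Tuple


def pad_batch(seqs: List[List[int]], pad_id: int,
              max_length: Optional[int] = None) -> Tuple[List[List[int]], List[List[int]]]:
    """Truncate each sequence to `max_length` (if given), then right-pad every
    sequence to the longest one left in the batch. Returns (ids, attention_mask)."""
    if max_length is not None:
        seqs = [s[:max_length] for s in seqs]
    longest = max(len(s) for s in seqs)
    ids = [s + [pad_id] * (longest - len(s)) for s in seqs]
    # 1 marks a real token, 0 marks padding
    mask = [[1] * len(s) + [0] * (longest - len(s)) for s in seqs]
    return ids, mask


# Two long "articles" are both cut to max_length=256, so the input width is 256
inputs, _ = pad_batch([list(range(900)), list(range(400))], pad_id=1, max_length=256)
# Two short "summaries" are only padded up to the longer of the two
targets, _ = pad_batch([list(range(68)), list(range(50))], pad_id=1)
print(len(inputs[0]), len(targets[0]), len(targets[1]))  # 256 68 68
```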
dls.show_batch(dataloaders=dls, max_n=2)
text target
0 Dan Condon believes in recycling. Just not when it comes to his hotel towels. Condon composts when he's at home in Boulder, Colorado. He eats local, organic and fair-trade food and drives a Honda CR-Z hybrid sports car. You might call him green. Except he's not so green when he travels for his work at an education nonprofit and stays in a hotel, which happens about 10 weeks per year. There, he uses a new towel every day. And don't try to bribe him with a drink or dessert coupon to get him to reuse the same one. "I could care less about rewards for environmentally conscious behavior unless it's miles," Condon wrote in an e-mail. If hotels can't convince a hybrid-driving recycling enthusiast like Condon to go green while traveling, how can they possibly convince everyone else? 9 glamorous movie-star hotels. That's the problem of hotels trying to "green" your hotel stay. After guests have paid a pretty penny for a night at the inn, even the most environmental guests may want to treat themselves to fresh towels every day and those little bottles of sweet-smelling shampoo. Despite the fact that most people describe themselves in surveys as environmentally conscious and as preferring green products, Hotel guests who "go green" are happier with their stay.\nIncreasing water and energy costs are pushing hotels to cut costs wherever they can.\nMany hotels find that guests don't mind using the same towels and sheets every night.\nTripAdvisor will be adding a green label for hotels listed on its site.
1 (CNN Student News) -- March 23, 2010. Download PDF maps related to today's show:. • Haiti • China. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. CARL AZUZ, CNN STUDENT NEWS ANCHOR: Happy birthday, Roger Bannister -- first man to run the mile in less than four minutes. In more than twice that time, you'll be up to speed on today's headlines. I'm Carl Azuz. First Up: Health Care. AZUZ: First up, it's the biggest expansion of the United States health care system in more than forty years. And by a vote of 219-212, the U.S. House of Representatives passed a health care reform bill late Sunday night. This is the same bill that the Senate passed last December. This means that when President Obama signs it, it's law. The House also passed a set of changes to the Senate bill. We're gonna get back to that in just a second. But first, you know this health care issue has been controversial. We want you to check out some of the reaction to last night's vote. REP. NANCY Find out what comes next after the passage of a health care reform bill.\nLearn about a proposal that would change how student loans are funded.\nFollow the steps that led to a showdown between China and Google.\nUse the Daily Discussion to help students understand today's featured news stories.

Train model

seq2seq_metrics = {
        'rouge': {
            'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
            'returns': ["rouge1", "rouge2", "rougeL"]
        },
        'bertscore': {
            'compute_kwargs': { 'lang': 'en' },
            'returns': ["precision", "recall", "f1"]
        }
    }
model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(dls, 
                model,
                opt_func=partial(Adam),
                loss_func=CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) #.to_native_fp16() #.to_fp16()

learn.create_opt() 
learn.freeze()
 
b = dls.one_batch()
preds = learn.model(b[0])

len(preds),preds['loss'].shape, preds['logits'].shape
(4, torch.Size([]), torch.Size([2, 78, 50264]))
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 256]), 2, torch.Size([2, 78]))
print(len(learn.opt.param_groups))
3
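`seq2seq_splitter` divides the model's parameters into three groups (hence the `3` printed above), and `learn.freeze()` leaves only the last group trainable. A plain-Python sketch of fastai's `freeze_to` semantics as we understand them, using trainable flags instead of real parameters. `freeze_to` here is a hypothetical helper, and fastai's version additionally keeps normalization layers trainable by default, which this sketch ignores:

```python
from typing import List


def freeze_to(n_groups: int, n: int) -> List[bool]:
    """Return a trainable flag per parameter group: groups before index `n` are
    frozen, groups from `n` onward are trainable. A negative `n` counts from the
    end, so n=-1 trains only the last group."""
    first_trainable = n if n >= 0 else n_groups + n
    return [i >= first_trainable for i in range(n_groups)]


print(freeze_to(3, -1))  # [False, False, True]  (what freeze() does)
print(freeze_to(3, 0))   # [True, True, True]    (what unfreeze() does)
```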
learn.lr_find(suggestions=True)
SuggestedLRs(lr_min=0.00010000000474974513, lr_steep=1.5848931980144698e-06)
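With `suggestions=True`, `lr_find` reports two heuristics. As we understand the fastai version used here, `lr_min` is the learning rate at the lowest recorded loss divided by 10, and `lr_steep` is the learning rate where the loss falls fastest with respect to `log(lr)`. A sketch of both on a made-up sweep (fastai also trims the beginning and end of the recorded sweep before computing them, which is omitted here; `suggest_lrs` is a hypothetical name):

```python
import math
from typing import List, Tuple


def suggest_lrs(lrs: List[float], losses: List[float]) -> Tuple[float, float]:
    """Return (lr_min, lr_steep) from a learning-rate sweep."""
    # Learning rate at the lowest loss, backed off by a factor of 10
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10
    # Finite-difference slope of loss against log(lr) between neighbouring points
    slopes = [(losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
              for i in range(len(lrs) - 1)]
    i_steep = min(range(len(slopes)), key=slopes.__getitem__)
    return lr_min, lrs[i_steep]


lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [2.0, 1.9, 1.2, 1.0, 3.0]  # made-up sweep, for illustration only
print(suggest_lrs(lrs, losses))
```

The `lr_max=4e-5` used for training below sits between the two suggestions printed above (about 1.6e-6 and 1e-4).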
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch train_loss valid_loss rouge1 rouge2 rougeL bertscore_precision bertscore_recall bertscore_f1 time
0 1.694427 1.697661 0.295115 0.130354 0.228090 0.886885 0.862009 0.874115 02:10
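`fit_one_cycle(1, lr_max=4e-5)` does not train at a constant 4e-5. Using what we believe are fastai's defaults (`pct_start=0.25`, `div=25`, `div_final=1e5`), the learning rate warms up along a cosine curve from `lr_max/div` to `lr_max` over the first quarter of training, then anneals to `lr_max/div_final`. A sketch of that schedule (fastai also cycles momentum, which is left out here; `one_cycle_lr` is a hypothetical name):

```python
import math


def cos_interp(start: float, end: float, pos: float) -> float:
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pos: float, lr_max: float, pct_start: float = 0.25,
                 div: float = 25.0, div_final: float = 1e5) -> float:
    """Learning rate at training progress `pos` in [0, 1]."""
    if pos < pct_start:
        # Warm-up phase: lr_max/div -> lr_max
        return cos_interp(lr_max / div, lr_max, pos / pct_start)
    # Annealing phase: lr_max -> lr_max/div_final
    return cos_interp(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


for pos in (0.0, 0.25, 0.5, 1.0):
    print(f"{pos:.2f} -> {one_cycle_lr(pos, lr_max=4e-5):.3e}")
```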
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 (CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development and climate change is placing a big strain on its biodiversity.\nThe Golden Eagle is un Mexico is home to up to 10 percent of all known species on the planet .\nIt is one of the most biodiverse
1 (CNN) -- It's a congested, sprawling transport hub surrounded by 1950s architecture and predominantly used by commuters or tourists to cross the city of Istanbul. But proposed changes to Taksim Square have seen it become the flashpoint for protests that have swept through Turkey in the past week, leaving thousands injured and focusing the world's attention on the government of Prime Minister Recep Tayyip Erdogan. Taksim has been no stranger to violence. In 1977, at least 34 protesters died duri Taksim Square was where Istanbul's water was distributed -- Taksim means divide.\nThe site is seen as symbolizing the seclar Turkish republic founded by Ataturk.\nErdogan's government's plans to alter Taksim's Gezi Park prompted protests.\nThe police's Taksim Square has been the flashpoint for protests that have swept through Turkey in the past week .\nIn 1977,
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French lRicense plates. CNN's Andreena Narayan 
contributed to this report.
"""
outputs = learn.blurr_generate(test_article, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 Police: 10 men with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police say .
Swiss authorities are working closely with French authorities, a police officer says .
Police: The robbers spoke French and drove vehicles with French lRicense plates .

=== Prediction 2 ===
 Police: 10 men with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
One group tried to break into the casino's vault on the lower level but could not get in, police say .
A woman driving by unknowingly blocked the robbers' vehicles with her car and was beaten to death .
There were about 600 people in the casino at the time of the robbery .

=== Prediction 3 ===
 Police: 10 men with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police say .
Swiss authorities are working closely with French authorities, a police officer says .
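The three predictions differ because `num_return_sequences=3` asks generation for several candidate summaries instead of only the best one. Assuming the beam-search settings from the model's summarization defaults are in effect, these are the highest-scoring beams, which is also why Hugging Face requires `num_return_sequences` to be no larger than `num_beams` for beam search. A toy sketch of the idea follows; it is a simplified beam search of our own (no length penalty, no `no_repeat_ngram_size`, no early stopping), not the Hugging Face implementation, and `toy_model` is a hypothetical stand-in for the decoder:

```python
import math
from typing import Callable, Dict, List, Tuple

Seq = Tuple[str, ...]


def beam_search(next_logprobs: Callable[[Seq], Dict[str, float]], num_beams: int,
                max_len: int, num_return_sequences: int = 1,
                eos: str = "</s>") -> List[Tuple[Seq, float]]:
    """Keep the `num_beams` best partial sequences at each step and return the
    best `num_return_sequences` overall (score = summed log-probabilities)."""
    if num_return_sequences > num_beams:
        raise ValueError("num_return_sequences must be <= num_beams")
    beams: List[Tuple[Seq, float]] = [((), 0.0)]
    finished: List[Tuple[Seq, float]] = []
    for _ in range(max_len):
        # Expand every live beam by every possible next token
        candidates = [(seq + (tok,), score + lp)
                      for seq, score in beams
                      for tok, lp in next_logprobs(seq).items()]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates:
            if seq[-1] == eos:
                finished.append((seq, score))  # hypothesis ended; stop expanding it
            else:
                beams.append((seq, score))
            if len(beams) == num_beams:
                break
        if not beams:
            break
    # Beams still open at max_len compete with the finished ones
    ranked = sorted(finished + beams, key=lambda c: c[1], reverse=True)
    return ranked[:num_return_sequences]


def toy_model(prefix: Seq) -> Dict[str, float]:
    """Same next-token distribution regardless of the prefix."""
    return {"a": math.log(0.6), "b": math.log(0.3), "</s>": math.log(0.1)}


for seq, score in beam_search(toy_model, num_beams=2, max_len=2, num_return_sequences=2):
    print(" ".join(seq), round(math.exp(score), 2))
```

Because the returned sequences are near neighbours in the search, they tend to share long prefixes, much like the three predictions above.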

Inference

export_fname = 'summarize_export'
learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_article)
[' 10 men with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .\n']

Tests

The purpose of the following tests is to ensure, as much as possible, that the core training code works for the pretrained summarization models below. These tests are excluded from the CI workflow because of how long they take to run and how much data they would need to download.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with. If any of your pretrained summarization models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

try: del learn; torch.cuda.empty_cache()
except NameError: pass  # `learn` was never created
[ model_type for model_type in BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration') 
 if (not model_type.__name__.startswith('TF')) ]
[transformers.models.bart.modeling_bart.BartForConditionalGeneration,
 transformers.models.blenderbot.modeling_blenderbot.BlenderbotForConditionalGeneration,
 transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallForConditionalGeneration,
 transformers.models.fsmt.modeling_fsmt.FSMTForConditionalGeneration,
 transformers.models.led.modeling_led.LEDForConditionalGeneration,
 transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration,
 transformers.models.mt5.modeling_mt5.MT5ForConditionalGeneration,
 transformers.models.pegasus.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration,
 transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
 transformers.models.xlm_prophetnet.modeling_xlm_prophetnet.XLMProphetNetForConditionalGeneration]
pretrained_model_names = [
    'facebook/bart-base',
    #'facebook/blenderbot_small-90M',
    'allenai/led-base-16384',
    'sshleifer/tiny-mbart',
    'google/mt5-small',
    'sshleifer/distill-pegasus-cnn-16-4',
    't5-small', 
    #'microsoft/prophetnet-large-uncased',
    #'microsoft/xprophetnet-large-wiki100-cased', # XLMProphetNet
]
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
#hide_output
task = HF_TASKS_AUTO.Seq2SeqLM
bsz = 2
inp_seq_sz = 64; trg_seq_sz = 40

test_results = []
for model_name in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, task=task)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n')

    # 1. build your DataBlock
    text_gen_kwargs = {}
    if (hf_arch in ['bart', 't5']):
        text_gen_kwargs = {**hf_config.task_specific_params['summarization'], **{'max_length': 30, 'min_length': 10}}
    
    # not all "summarization" parameters are for the model.generate method ... remove them here
    generate_func_args = list(inspect.signature(hf_model.generate).parameters.keys())
    for k in text_gen_kwargs.copy():
        if k not in generate_func_args: del text_gen_kwargs[k]
            
    if (hf_arch == 'mbart'):
        text_gen_kwargs['decoder_start_token_id'] = hf_tokenizer.get_vocab()["en_XX"]
            
    tok_kwargs = {}
    if (hf_arch == 'mbart'):
        tok_kwargs['src_lang'], tok_kwargs['tgt_lang'] = "en_XX", "en_XX"
            
    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp
    
    before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                      padding='max_length', 
                                                      max_length=inp_seq_sz, 
                                                      max_target_length=trg_seq_sz, 
                                                      tok_kwargs=tok_kwargs, text_gen_kwargs=text_gen_kwargs)
    
    blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
    dblock = DataBlock(blocks=blocks, 
                       get_x=Pipeline([ColReader('article'), add_t5_prefix]), 
                       get_y=ColReader('highlights'), 
                       splitter=RandomSplitter())

    dls = dblock.dataloaders(cnndm_df, bs=bsz) 
    b = dls.one_batch()

    # 2. build your Learner
    seq2seq_metrics = {
        'rouge': {
            'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
            'returns': ["rouge1", "rouge2", "rougeL"]
        }
    }
    
    model = HF_BaseModelWrapper(hf_model)
    learn_cbs = [HF_BaseModelCallback]
    fit_cbs = [
        ShortEpochCallback(0.05, short_valid=True), 
        HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)
    ]
 
    learn = Learner(dls, 
                    model,
                    opt_func=ranger,
                    loss_func=HF_PreCalculatedLoss(),
                    cbs=learn_cbs,
                    splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

    learn.create_opt() 
    learn.freeze()
    
    # 3. Run your tests
    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

#         print('*** TESTING One pass through the model ***')
#         preds = learn.model(b[0])
#         test_eq(preds[1].shape[0], bsz)
#         test_eq(preds[1].shape[2], hf_config.vocab_size)

        print('*** TESTING Training/Results ***')
        learn.fit_one_cycle(1, lr_max=1e-3, cbs=fit_cbs)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'PASSED', ''))
        learn.show_results(learner=learn, max_n=2, input_trunc_at=500, target_trunc_at=250)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'FAILED', err))
    finally:
        # cleanup
        del learn; torch.cuda.empty_cache()
arch tokenizer model_name result error
0 bart BartTokenizerFast BartForConditionalGeneration PASSED
1 led LEDTokenizerFast LEDForConditionalGeneration PASSED
2 mbart MBartTokenizerFast MBartForConditionalGeneration PASSED
3 mt5 T5TokenizerFast MT5ForConditionalGeneration PASSED
4 pegasus PegasusTokenizerFast PegasusForConditionalGeneration PASSED
5 t5 T5TokenizerFast T5ForConditionalGeneration PASSED
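The test loop above records a PASSED/FAILED row per model instead of stopping at the first exception, and frees GPU memory in `finally` whether or not the model passed. The same pattern, stripped of the model-specific pieces; `run_checks`, `ok_check`, and `broken_check` are hypothetical names used only for illustration:

```python
from typing import Callable, List, Tuple


def run_checks(checks: List[Tuple[str, Callable[[], None]]],
               cleanup: Callable[[], None] = lambda: None) -> List[Tuple[str, str, str]]:
    """Run each named check, recording PASSED or FAILED (with the error text)
    instead of aborting, and always run `cleanup` afterwards."""
    results = []
    for name, check in checks:
        try:
            check()
            results.append((name, "PASSED", ""))
        except Exception as err:  # record the failure and move on to the next model
            results.append((name, "FAILED", repr(err)))
        finally:
            cleanup()  # e.g. free GPU memory between models
    return results


def ok_check() -> None:
    assert 2 + 2 == 4


def broken_check() -> None:
    raise ValueError("shape mismatch")


cleanups: List[int] = []
results = run_checks([("bart", ok_check), ("t5", broken_check)],
                     cleanup=lambda: cleanups.append(1))
for row in results:
    print(row)
```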

Cleanup