This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5.
 
What we're running with at the time this documentation was generated:
torch: 1.9.0+cu102
fastai: 2.5.2
transformers: 4.10.0

Summarization tokenization, batch transform, and DataBlock methods

Summarization tasks attempt to generate a human-understandable and sensible representation of a larger body of text (e.g., capturing the meaning of a larger document in 1-3 sentences).
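
The code in this document assumes imports along the following lines. This is a minimal sketch: the blurr import paths are an assumption based on the 0.x-era library matching the versions listed above, and may differ in later releases.

import pandas as pd
import torch

from pathlib import Path
from fastai.text.all import *
from fastcore.test import *  # test_eq, used in the Tests section below

from transformers import AutoModelForSeq2SeqLM

# Assumed blurr imports for this era of the library; these bring in BLURR,
# HF_Seq2SeqBlock, HF_Seq2SeqBeforeBatchTransform, and the show_batch patches
from blurr.utils import *
from blurr.data.all import *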

path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
1000
cnndm_df.head(2)
row 0
  article:    (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el...
  highlights: John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge"
  ds_type:    train

row 1
  article:    (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ...
  highlights: NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says .
  ds_type:    train
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('bart',
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
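
Before wiring these objects into a DataBlock, it can be reassuring to exercise them directly. The snippet below is a hypothetical sanity check (not part of the original notebook) that summarizes one article with the raw Hugging Face objects:

# Hypothetical sanity check: summarize a single article with the raw HF objects
inputs = hf_tokenizer(cnndm_df['article'][0], max_length=1024, truncation=True, return_tensors='pt')
summary_ids = hf_model.generate(inputs['input_ids'], num_beams=4, max_length=60, early_stopping=True)
print(hf_tokenizer.decode(summary_ids[0], skip_special_tokens=True))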
blocks = (HF_Seq2SeqBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())

Two lines! Notice we pass in noop for our targets (i.e., our summaries) because the batch transform takes care of both our inputs and targets.
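
If you need finer control over how examples are tokenized (fixed padding, shorter sequences, etc.), you can build the before-batch transform yourself and hand it to the block. The tests later in this document use exactly this pattern; the lengths below are illustrative:

# Optional alternative: configure tokenization explicitly instead of relying on the defaults
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model,
    padding='max_length',    # pad every example out to max_length
    max_length=256,          # cap on article (input) tokens
    max_target_length=40)    # cap on summary (target) tokens

blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)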

 
dls = dblock.dataloaders(cnndm_df, bs=4)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[0]['labels'].shape, b[1].shape
(2, torch.Size([4, 1024]), torch.Size([4, 78]), torch.Size([4, 78]))
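
Here b[0] is the dictionary of model inputs (which, as the shapes above show, already includes the labels) and b[1] holds the target ids. The snippet below is a hypothetical way to eyeball the batch contents; depending on how the transform is configured, padded target positions may be masked with -100, so negative ids are filtered out before decoding:

# Decode the first example back to text to sanity-check the batch contents
print(hf_tokenizer.decode(b[0]['input_ids'][0], skip_special_tokens=True)[:500])

# Target ids may use -100 for ignored/padded positions; drop them before decoding
target_ids = [tok for tok in b[1][0].tolist() if tok >= 0]
print(hf_tokenizer.decode(target_ids, skip_special_tokens=True))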
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000, target_trunc_at=250)
row 0
  text:   <s> Dan Condon believes in recycling. Just not when it comes to his hotel towels. Condon composts when he's at home in Boulder, Colorado. He eats local, organic and fair-trade food and drives a Honda CR-Z hybrid sports car. You might call him green. Except he's not so green when he travels for his work at an education nonprofit and stays in a hotel, which happens about 10 weeks per year. There, he uses a new towel every day. And don't try to bribe him with a drink or dessert coupon to get him to reuse the same one. "I could care less about rewards for environmentally conscious behavior unless it's miles," Condon wrote in an e-mail. If hotels can't convince a hybrid-driving recycling enthusiast like Condon to go green while traveling, how can they possibly convince everyone else? 9 glamorous movie-star hotels. That's the problem of hotels trying to "green" your hotel stay. After guests have paid a pretty penny for a night at the inn, even the most environmental guests may want to treat
  target: Hotel guests who "go green" are happier with their stay.\nIncreasing water and energy costs are pushing hotels to cut costs wherever they can.\nMany hotels find that guests don't mind using the same towels and sheets every night.\nTripAdvisor will be a

row 1
  text:   <s> (CNN) -- Two weeks. Two gut-wrenching, frustrating, mysterious weeks. That's how long it's been since 227 passengers and 12 crew members boarded Malaysia Airlines Flight 370, destined for Beijing. A routine trip, it seemed, to catch up relatives in time for the weekend, start on a work assignment or just get away. Where they got to, still unknown. An exhaustive search -- covering a mind-boggling 2.97 million square miles, which is nearly the size of the continental United States -- has yielded some clues, but no proof of where the Boeing 777 is or definitively what happened to it. The latest, most notable lead revolved around two large objects detected by satellite Sunday floating on waters over 1,400 miles off of Australia's west coast. The first of several Australian military planes, as well as two long-range commercial jets, resumed their search Saturday morning to find any trace of the objects, amid some skepticism that they or ships in the area ever will and, if they do, that
  target: NEW: Planes depart Australia to resume their search for airplane debris.\nNEW: Official: Passengers' relatives are moved to a different Kuala Lumpur hotel.\nObjects seen on satellite spark intensive search in southern Indian Ocean.\nU.S. officials: Fil

Tests

The purpose of the following tests is to ensure, as much as possible, that the core DataBlock code above works for the pretrained summarization models below. These tests are excluded from the CI workflow because of how long they take to run and the amount of data they would need to download.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with ... and if any of your pretrained summarization models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

[ model_type for model_type in BLURR.get_models(task='ConditionalGeneration') 
 if (not model_type.startswith('TF')) ]
['BartForConditionalGeneration',
 'BigBirdPegasusForConditionalGeneration',
 'BlenderbotForConditionalGeneration',
 'BlenderbotSmallForConditionalGeneration',
 'FSMTForConditionalGeneration',
 'LEDForConditionalGeneration',
 'M2M100ForConditionalGeneration',
 'MBartForConditionalGeneration',
 'MT5ForConditionalGeneration',
 'PegasusForConditionalGeneration',
 'ProphetNetForConditionalGeneration',
 'Speech2TextForConditionalGeneration',
 'T5ForConditionalGeneration',
 'XLMProphetNetForConditionalGeneration']
pretrained_model_names = [
    'facebook/bart-base',
    'facebook/blenderbot_small-90M',
    'allenai/led-base-16384',
    'google/mt5-small',
    'google/pegasus-cnn_dailymail',
    't5-small', 
    'microsoft/prophetnet-large-uncased',
    'microsoft/xprophetnet-large-wiki100-cased', # XLMProphetNet
]
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
model_cls = AutoModelForSeq2SeqLM
bsz = 2
seq_sz = 256
trg_seq_sz = 40

test_results = []
for model_name in pretrained_model_names:
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(model_name, model_cls=model_cls)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')
    
    # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
    if (hf_tokenizer.pad_token is None): 
        hf_tokenizer.add_special_tokens({'pad_token': '<pad>'})  
        hf_config.pad_token_id = hf_tokenizer.get_vocab()['<pad>']
        hf_model.resize_token_embeddings(len(hf_tokenizer))   
    
    before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                      padding='max_length', 
                                                      max_length=seq_sz, 
                                                      max_target_length=trg_seq_sz)
    
    # T5 checkpoints expect a task prefix on the input (e.g., "summarize: ")
    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp
    
    blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
    dblock = DataBlock(blocks=blocks,
                       get_x=Pipeline([ColReader('article'), add_t5_prefix]),
                       get_y=ColReader('highlights'),
                       splitter=RandomSplitter())

    try:
        print('*** TESTING DataLoaders ***\n')

        # build the DataLoaders inside the try block so a failing model is
        # recorded as FAILED rather than aborting the whole loop
        dls = dblock.dataloaders(cnndm_df, bs=bsz)
        b = dls.one_batch()

        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz]))

        if (hasattr(hf_tokenizer, 'add_prefix_space') and hf_arch not in ['led']):
            test_eq(hf_tokenizer.add_prefix_space, True)

        # show the batch before recording PASSED so a display failure counts as a failure
        dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000)
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'PASSED', ''))

    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'FAILED', err))
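
The results table below was presumably rendered from test_results; a cell along the following lines would produce it (the column names are assumptions inferred from the table header):

# Hypothetical: display the accumulated test results as a DataFrame
pd.DataFrame(test_results, columns=['arch', 'tokenizer', 'model_name', 'result', 'error'])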
   arch              tokenizer                  model_name                                 result  error
0  bart              BartTokenizerFast          facebook/bart-base                         PASSED
1  blenderbot_small  BlenderbotSmallTokenizer   facebook/blenderbot_small-90M              PASSED
2  led               LEDTokenizerFast           allenai/led-base-16384                     PASSED
3  mt5               T5TokenizerFast            google/mt5-small                           PASSED
4  pegasus           PegasusTokenizerFast       google/pegasus-cnn_dailymail               PASSED
5  t5                T5TokenizerFast            t5-small                                   PASSED
6  prophetnet        ProphetNetTokenizer        microsoft/prophetnet-large-uncased         PASSED
7  xlm_prophetnet    XLMProphetNetTokenizer     microsoft/xprophetnet-large-wiki100-cased  PASSED

Summary

This module includes the fundamental data preprocessing bits to use Blurr for summarization.