Data

The text.data.seq2seq.summarization module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5. Summarization tasks attempt to generate a human-understandable and sensible representation of a larger body of text (e.g., capture the meaning of a larger document in 1-3 sentences).

Setup

We’ll use a subset of cnn_dailymail to demonstrate how to configure BLURR for summarization tasks.

import pandas as pd
from datasets import load_dataset

raw_datasets = load_dataset("cnn_dailymail", "3.0.0", split=["train", "validation"])
raw_datasets
[Dataset({
     features: ['article', 'highlights', 'id'],
     num_rows: 287113
 }),
 Dataset({
     features: ['article', 'highlights', 'id'],
     num_rows: 13368
 })]
print(raw_datasets[0][0].keys())
print(raw_datasets[0][0]["highlights"])

print(raw_datasets[1][0].keys())
print(raw_datasets[1][0]["highlights"])
dict_keys(['article', 'highlights', 'id'])
Syrian official: Obama climbed to the top of the tree, "doesn't know how to get down"
Obama sends a letter to the heads of the House and Senate .
Obama to seek congressional approval on military action against Syria .
Aim is to determine whether CW were used, not by whom, says U.N. spokesman .
dict_keys(['article', 'highlights', 'id'])
Accident happens in Santa Ynez, California, near where Crosby lives .
The jogger suffered multiple fractures; his injuries are not believed to be life-threatening .
raw_train_ds = raw_datasets[0].shuffle(seed=42).select(range(1000))
raw_valid_ds = raw_datasets[1].shuffle(seed=42).select(range(200))

len(raw_train_ds) + len(raw_valid_ds)
1200
raw_train_df = pd.DataFrame(raw_train_ds)
raw_valid_df = pd.DataFrame(raw_valid_ds)

raw_train_df.head(2)
article highlights id
0 A protester in Ferguson was arrested during a demonstration on Thursday night - and live-tweeted her entire experience. Brittany Ferrell, a nursing student at the University of Missouri-Saint Louis, was one of 13 people detained by officers in the conflicted Missouri city for 'noise disruption'. The detention has sparked an investigation by the American Civil Liberties Union as lawyers accuse officers of overstretching their powers. Scroll down for video . Arrested: This is Brittany Ferrell, the nursing student and protester who live-tweeted her arrest in Ferguson . Tweeting in handcuffs, ... Brittany Ferrell, nursing student, was arrested with 12 people on Thursday .\nThey were calling on police take responsibility for Michael Brown's death .\nMs Ferrell tweeted as she was arrested, piled in a small wagon with 7 others .\nThey were accused of 'noise disruption', put in orange jumpsuits and cuffed .\nOfficers now being investigated, lawyers claim they 'overstretched powers' 1e01f238418c31d4e9093f6334e0232babeb639a
1 A day after confirming it had lost the ability to display Instagram images, Twitter has rolled out its own library of retro filters for its Android and iPhone apps. The eight filters are the usual suspects we've come to expect from mobile photo apps, including desaturated, black and white and high contrast. There are auto-adjust and cropping options, as well as a helpful grid view that lets you see what each filter will look like at once. "The latest versions of Twitter for iPhone and Twitter for Android introduce a few new ways to enhance the images you tweet," said Twitter senior designe... Twitter has added photo filters to its Android and iOS mobile apps .\nThe addition will help Twitter compete against Facebook-owned Instagram .\nThis is the first time the social network has offered image editing tools . 6f89645bff243fe9ce2a0509e5ca01912abf0d10
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('bart',
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)

Preprocessing

Starting with version 2.0, BLURR provides a preprocessing base class that can be used to build task-specific preprocessed datasets from pandas DataFrames or Hugging Face Datasets.


SummarizationPreprocessor

 SummarizationPreprocessor (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                            batch_size:int=1000,
                            id_attr:Optional[str]=None,
                            text_attr:str='text',
                            max_input_tok_length:Optional[int]=None,
                            target_text_attr:str='summary',
                            max_target_tok_length:Optional[int]=None,
                            min_summary_char_length:Optional[int]=None,
                            is_valid_attr:Optional[str]='is_valid',
                            tok_kwargs:dict={})

Parameter | Type | Default | Details
hf_tokenizer | PreTrainedTokenizerBase | | A Hugging Face tokenizer
batch_size | int | 1000 | The number of examples to process at a time
id_attr | Optional | None | The unique identifier in the dataset
text_attr | str | text | The attribute holding the text
max_input_tok_length | Optional | None | The maximum length (# of tokens) allowed for inputs. Will default to the max length allowed by the model if not provided
target_text_attr | str | summary | The attribute holding the summary
max_target_tok_length | Optional | None | The maximum length (# of tokens) allowed for targets
min_summary_char_length | Optional | None | If not “None”, any examples where “target_text_attr” is < “min_summary_char_length” will be removed
is_valid_attr | Optional | is_valid | The attribute that should be created if you are processing individual training and validation datasets into a single dataset; it indicates which of the two each example belongs to
tok_kwargs | dict | {} | Tokenization kwargs that will be applied when calling the tokenizer

This class can be used to preprocess data for summarization tasks. Its output includes proc_{text_attr} and proc_{target_text_attr} attributes containing your input and target texts as modified by tokenization (e.g., if you specify a maximum length, proc_{text_attr} may contain truncated text).
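To make two of the options above concrete, here is a minimal sketch of the bookkeeping involved: merging training and validation examples into one collection flagged by an is_valid field, and dropping examples whose summary is shorter than a minimum character length. It is plain Python over dictionaries and is not BLURR's implementation; the helper name `merge_and_filter` and the sample records are hypothetical.

```python
from typing import Dict, List, Optional


def merge_and_filter(
    train: List[Dict[str, str]],
    valid: List[Dict[str, str]],
    target_text_attr: str = "summary",
    min_summary_char_length: Optional[int] = None,
    is_valid_attr: str = "is_valid",
) -> List[Dict[str, object]]:
    """Hypothetical helper: combine two splits, then filter out short summaries."""
    merged = []
    for is_valid, split in ((False, train), (True, valid)):
        for example in split:
            row = dict(example)  # copy so the caller's data is left untouched
            row[is_valid_attr] = is_valid
            merged.append(row)

    if min_summary_char_length is not None:
        # Keep only rows whose summary has at least the minimum number of characters
        merged = [r for r in merged if len(r[target_text_attr]) >= min_summary_char_length]
    return merged


# Made-up records, shaped like the cnn_dailymail columns used on this page
train = [
    {"article": "A long story about filters.", "highlights": "Twitter adds photo filters."},
    {"article": "Another story.", "highlights": "Short"},
]
valid = [{"article": "A validation story.", "highlights": "A validation summary."}]

rows = merge_and_filter(train, valid, target_text_attr="highlights", min_summary_char_length=10)
for r in rows:
    print(r["is_valid"], r["highlights"])
```

The second training record is dropped because its summary has fewer than 10 characters, and the surviving rows carry the flag that a column-based splitter can use later.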

Using a DataFrame

preprocessor = SummarizationPreprocessor(
    hf_tokenizer,
    id_attr="id",
    text_attr="article",
    target_text_attr="highlights",
    max_input_tok_length=128,
    max_target_tok_length=30,
    min_summary_char_length=10,
)
proc_df = preprocessor.process_df(raw_train_df, raw_valid_df)
proc_df.columns, len(proc_df)
proc_df.head(2)
proc_highlights proc_article article highlights id is_valid article_start_char_idx article_end_char_idx highlights_start_char_idx highlights_end_char_idx
0 Brittany Ferrell, nursing student, was arrested with 12 people on Thursday .\nThey were calling on police take responsibility for Michael Brown's death A protester in Ferguson was arrested during a demonstration on Thursday night - and live-tweeted her entire experience. Brittany Ferrell, a nursing student at the University of Missouri-Saint Louis, was one of 13 people detained by officers in the conflicted Missouri city for 'noise disruption'. The detention has sparked an investigation by the American Civil Liberties Union as lawyers accuse officers of overstretching their powers. Scroll down for video . Arrested: This is Brittany Ferrell, the nursing student and protester who live-tweeted her arrest in Ferguson . Tweeting in handcuffs, ... A protester in Ferguson was arrested during a demonstration on Thursday night - and live-tweeted her entire experience. Brittany Ferrell, a nursing student at the University of Missouri-Saint Louis, was one of 13 people detained by officers in the conflicted Missouri city for 'noise disruption'. The detention has sparked an investigation by the American Civil Liberties Union as lawyers accuse officers of overstretching their powers. Scroll down for video . Arrested: This is Brittany Ferrell, the nursing student and protester who live-tweeted her arrest in Ferguson . Tweeting in handcuffs, ... Brittany Ferrell, nursing student, was arrested with 12 people on Thursday .\nThey were calling on police take responsibility for Michael Brown's death .\nMs Ferrell tweeted as she was arrested, piled in a small wagon with 7 others .\nThey were accused of 'noise disruption', put in orange jumpsuits and cuffed .\nOfficers now being investigated, lawyers claim they 'overstretched powers' 1e01f238418c31d4e9093f6334e0232babeb639a False 0 648 0 150
1 Twitter has added photo filters to its Android and iOS mobile apps .\nThe addition will help Twitter compete against Facebook-owned Instagram .\nThis A day after confirming it had lost the ability to display Instagram images, Twitter has rolled out its own library of retro filters for its Android and iPhone apps. The eight filters are the usual suspects we've come to expect from mobile photo apps, including desaturated, black and white and high contrast. There are auto-adjust and cropping options, as well as a helpful grid view that lets you see what each filter will look like at once. "The latest versions of Twitter for iPhone and Twitter for Android introduce a few new ways to enhance the images you tweet," said Twitter senior designe... A day after confirming it had lost the ability to display Instagram images, Twitter has rolled out its own library of retro filters for its Android and iPhone apps. The eight filters are the usual suspects we've come to expect from mobile photo apps, including desaturated, black and white and high contrast. There are auto-adjust and cropping options, as well as a helpful grid view that lets you see what each filter will look like at once. "The latest versions of Twitter for iPhone and Twitter for Android introduce a few new ways to enhance the images you tweet," said Twitter senior designe... Twitter has added photo filters to its Android and iOS mobile apps .\nThe addition will help Twitter compete against Facebook-owned Instagram .\nThis is the first time the social network has offered image editing tools . 6f89645bff243fe9ce2a0509e5ca01912abf0d10 False 0 635 0 147
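In the output above, the *_start_char_idx and *_end_char_idx columns line up with the truncated text: the first proc_highlights value is 150 characters long, matching its highlights_end_char_idx of 150. The sketch below shows one way that truncating to a token budget can yield such character offsets. It uses a whitespace tokenizer as a stand-in (an assumption; a Hugging Face tokenizer splits text into subword tokens instead), and `truncate_with_offsets` is a hypothetical name.

```python
import re
from typing import Tuple


def truncate_with_offsets(text: str, max_tokens: int) -> Tuple[str, int, int]:
    """Keep at most `max_tokens` whitespace-delimited tokens.

    Returns the truncated text plus the start/end character indices of that
    text within the original string.
    """
    # Each match is one "token"; span() gives its (start, end) character offsets
    spans = [m.span() for m in re.finditer(r"\S+", text)]
    kept = spans[:max_tokens]
    if not kept:
        return "", 0, 0
    start, end = kept[0][0], kept[-1][1]
    return text[start:end], start, end


summary = "Twitter has added photo filters to its mobile apps .\nThis is a first ."
proc, start_idx, end_idx = truncate_with_offsets(summary, max_tokens=6)
print(repr(proc), start_idx, end_idx)
```

Slicing the original string with the returned indices reproduces the truncated text, which is the same relationship the DataFrame columns exhibit.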

Examples

Using the mid-level API

Batch-Time Tokenization

Step 1: Get your Hugging Face objects.
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
Step 2: Create your DataBlock

Two lines! Notice we pass in noop for our targets (i.e., our summaries) because the batch transform will take care of both our inputs and targets.

blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("article"), get_y=ColReader("highlights"), splitter=RandomSplitter())
# dblock.summary(cnndm_df)
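RandomSplitter is called here with its defaults, and fastai's default valid_pct is 0.2, so a fifth of the rows are held out at random for validation. Roughly, the idea is the following. This is a plain-Python sketch rather than fastai's implementation; `random_split` is a hypothetical name, and the fixed seed is only there to make the sketch repeatable.

```python
import random
from typing import List, Tuple


def random_split(n_items: int, valid_pct: float = 0.2, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Return (train_idxs, valid_idxs) from a shuffled range of indices."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    # The first `cut` shuffled indices become validation, the rest training
    return idxs[cut:], idxs[:cut]


train_idxs, valid_idxs = random_split(1000)
print(len(train_idxs), len(valid_idxs))
```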
Step 3: Build your DataLoaders
dls = dblock.dataloaders(raw_train_df, bs=4)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[0]["labels"].shape, b[1].shape
(2, torch.Size([4, 1024]), torch.Size([4, 152]), torch.Size([4, 152]))
b[0]["labels"][0], b[1][0]
(tensor([    0,   270,  3905,  2950,   516,     9,   908,    25,    37,  5586,
           940,  2355,   375,   479, 50118,  9167,   703,    15,     5,   276,
           183,  1284,  2922, 11137,  4457,    30,   299,   940,  2355,  3504,
            11,   188,   469,   479,     2,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100], device='cuda:1'),
 tensor([    0,   270,  3905,  2950,   516,     9,   908,    25,    37,  5586,
           940,  2355,   375,   479, 50118,  9167,   703,    15,     5,   276,
           183,  1284,  2922, 11137,  4457,    30,   299,   940,  2355,  3504,
            11,   188,   469,   479,     2,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100], device='cuda:1'))
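The trailing -100 values in the label tensors above are padding. -100 is the default ignore_index of PyTorch's CrossEntropyLoss, so those positions do not contribute to the loss. Below is a minimal sketch of padding variable-length label sequences this way, using plain lists rather than tensors; `pad_labels` is a hypothetical helper, not a BLURR function, and the token ids are illustrative.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_labels(batch: List[List[int]], pad_value: int = IGNORE_INDEX) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one in the batch."""
    longest = max(len(seq) for seq in batch)
    return [seq + [pad_value] * (longest - len(seq)) for seq in batch]


labels = [[0, 270, 3905, 2], [0, 9167, 2]]
padded = pad_labels(labels)
print(padded)
```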
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000, target_trunc_at=250)
text target
0 <s> By. Daily Mail Reporter. PUBLISHED:. 08:16 EST, 14 May 2012. |. UPDATED:. 22:07 EST, 14 May 2012. Barack Obama's latest campaign gambit follows a familiar line of attack as it uses Mitt Romney's private equity past to cast the Republican candidate as greedy, job-killing corporate titan with little concern for the working class. The President is not the first of Mr Romney's opponents to try and paint the former governor of Massachusetts as a heartless uber-capitalist - even his Republican rivals used the same tactic during the heated primary battle. But Mr Obama's campaign seems to have been particularly unoriginal - as his attack ad is almost identical to one produced by Ted Kennedy for his Senate campaign against Mr Romney in 1994, featuring unemployed workers complaining about Bain Capital, the firm founded by Mr Romney. The timing of the Obama assault on private equity is also unfortunate, as on Monday night the President attended a fundraiser hosted by Democratic supporter Ham President follows familiar line of attack as he highlights private equity past.\nAd released on the same day Obama attended fundraiser hosted by top private equity boss in New York.
1 <s> (CNN) -- Voters in North Carolina, Indiana and Ohio on Tuesday kick off five straight weeks of primary contests that could give us a clearer indication of whether establishment Republicans have the upper hand against the tea party movement for control of the party. The results could back up recent tough talk from Senate GOP leader Mitch McConnell, who predicted big wins for incumbents facing primary challenges from the right, saying, "I think we are going to crush them everywhere." And they may have a major impact in determining whether Republicans retake the majority in the Senate. Since the birth of the tea party movement in 2009, primary challenges from the right have produced major headlines and headaches for the GOP and hurt the party's chances of winning back the Senate from Democrats in the past two election cycles. Candidates backed by the tea party movement and other grass-roots conservatives effectively cost the GOP five winnable Senate elections the last two cycles in Ne Establishment Republicans are fighting back more strongly against challenges from the right.\nWith a number of vulnerable Democrats in the Senate, GOP thinks it can win control.\nNorth Carolina primary seen as a key test of establishment-vs.-tea party

Using a preprocessed dataset

Step 1a: Get your Hugging Face objects.
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
Step 1b: Preprocess your dataset
preprocessor = SummarizationPreprocessor(
    hf_tokenizer,
    id_attr="id",
    text_attr="article",
    target_text_attr="highlights",
    max_input_tok_length=128,
    max_target_tok_length=30,
    min_summary_char_length=10,
)
proc_df = preprocessor.process_df(raw_train_df, raw_valid_df)
Step 2: Create your DataBlock
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("proc_article"), get_y=ColReader("proc_highlights"), splitter=ColSplitter())
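Because proc_df carries an is_valid column, ColSplitter (whose default column name in fastai is is_valid) can reproduce the original train/validation division instead of splitting at random. Conceptually it amounts to the following plain-Python sketch; `split_by_column` and the sample rows are hypothetical, and this is not fastai's code.

```python
from typing import Dict, List, Tuple


def split_by_column(rows: List[Dict[str, object]], col: str = "is_valid") -> Tuple[List[int], List[int]]:
    """Return (train_idxs, valid_idxs) based on a boolean column."""
    train_idxs = [i for i, r in enumerate(rows) if not r[col]]
    valid_idxs = [i for i, r in enumerate(rows) if r[col]]
    return train_idxs, valid_idxs


sample_rows = [
    {"proc_article": "first article", "is_valid": False},
    {"proc_article": "second article", "is_valid": True},
    {"proc_article": "third article", "is_valid": False},
]
train_idxs, valid_idxs = split_by_column(sample_rows)
print(train_idxs, valid_idxs)
```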
Step 3: Build your DataLoaders
dls = dblock.dataloaders(proc_df, bs=4)
dls.show_batch(dataloaders=dls, max_n=2, trunc_at=500)
text target
0 <s> Washington (CNN) -- A post-mortem Sunday of the mid-term elections provided little evidence that Democrats and Republicans will work together to address major issues such as deficit reduction any better than they have in recent years. Republicans interviewed on talk shows promised congressional investigations, an all-out effort to repeal health care reform, and steadfast opposition to any form of higher taxes. Democrats, meanwhile, said the losses they suffered in the congressional elections reflected voter dissatisfaction with lingering high unemployment in the slow recovery from economic recession, rather than an outright repudiation of their policies. Republicans won more than 60 seats formerly held by Democrats to take majority control of </s> GOP targets health care reform, government spending.\n"Are we willing to work with him?" Cantor says of President Obama.\nObama says
1 <s> Scientists believe they have discovered how to'switch off' autoimmune diseases, prompting hope the breakthrough could pave the way for a new treatment for multiple sclerosis. Researchers at the University of Bristol, who describe the work as an 'important breakthrough', say it could improve the lives of millions around the world. The study reveals how to stop cells from attacking healthy body tissue. The team discovered how cells convert from being aggressive to protecting against disease, rather than the body's immune system destroying its own tissue by mistake. Scientists at the University of Bristol have discovered how to'switch off' autoimmune diseases, which they hope will pave the way for new </s> Team at Bristol University have described their work as a 'breakthrough'\nDiscovered a way to stop cells from attacking healthy body tissue.\n

Tests

The purpose of the following tests is to ensure, as much as possible, that the core DataBlock code above works for the pretrained summarization models below. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with. If any of your pretrained summarization models fail, please submit a GitHub issue (or a PR if you’d like to fix it yourself).

[model_type for model_type in NLP.get_models(task="ConditionalGeneration") if (not model_type.startswith("TF"))]
['BartForConditionalGeneration',
 'BigBirdPegasusForConditionalGeneration',
 'BlenderbotForConditionalGeneration',
 'BlenderbotSmallForConditionalGeneration',
 'FSMTForConditionalGeneration',
 'LEDForConditionalGeneration',
 'M2M100ForConditionalGeneration',
 'MBartForConditionalGeneration',
 'MT5ForConditionalGeneration',
 'PegasusForConditionalGeneration',
 'ProphetNetForConditionalGeneration',
 'Speech2TextForConditionalGeneration',
 'T5ForConditionalGeneration',
 'XLMProphetNetForConditionalGeneration']
pretrained_model_names = [
    "facebook/bart-base",
    "facebook/blenderbot_small-90M",
    "allenai/led-base-16384",
    "google/mt5-small",
    "google/pegasus-cnn_dailymail",
    "t5-small",
    "microsoft/prophetnet-large-uncased",
    "microsoft/xprophetnet-large-wiki100-cased",  # XLMProphetNet
]
path = Path("./")
cnndm_df = pd.read_csv(path / "cnndm_sample.csv")
model_cls = AutoModelForSeq2SeqLM
bsz = 2
seq_sz = 256
trg_seq_sz = 40

test_results = []
for model_name in pretrained_model_names:
    error = None

    print(f"=== {model_name} ===\n")

    hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(model_name, model_cls=model_cls)
    print(f"architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n")

    # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.add_special_tokens({"pad_token": "<pad>"})
        hf_config.pad_token_id = hf_tokenizer.get_vocab()["<pad>"]
        hf_model.resize_token_embeddings(len(hf_tokenizer))

    batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
        hf_arch, hf_config, hf_tokenizer, hf_model, padding="max_length", max_length=seq_sz, max_target_length=trg_seq_sz
    )

    def add_t5_prefix(inp):
        return f"summarize: {inp}" if (hf_arch == "t5") else inp

    blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
    dblock = DataBlock(
        blocks=blocks, get_x=Pipeline([ColReader("article"), add_t5_prefix]), get_y=ColReader("highlights"), splitter=RandomSplitter()
    )

    dls = dblock.dataloaders(cnndm_df, bs=bsz)
    b = dls.one_batch()

    try:
        print("*** TESTING DataLoaders ***\n")
        test_eq(len(b), 2)
        test_eq(len(b[0]["input_ids"]), bsz)
        test_eq(b[0]["input_ids"].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz]))

        if hasattr(hf_tokenizer, "add_prefix_space") and hf_arch not in ["led"]:
            test_eq(hf_tokenizer.add_prefix_space, True)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "PASSED", ""))
        dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000)

    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "FAILED", err))
arch tokenizer model_name result error
0 bart BartTokenizerFast facebook/bart-base PASSED
1 blenderbot_small BlenderbotSmallTokenizer facebook/blenderbot_small-90M PASSED
2 led LEDTokenizerFast allenai/led-base-16384 PASSED
3 mt5 T5TokenizerFast google/mt5-small PASSED
4 pegasus PegasusTokenizerFast google/pegasus-cnn_dailymail PASSED
5 t5 T5TokenizerFast t5-small PASSED
6 prophetnet ProphetNetTokenizer microsoft/prophetnet-large-uncased PASSED
7 xlm_prophetnet XLMProphetNetTokenizer microsoft/xprophetnet-large-wiki100-cased PASSED
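The add_t5_prefix step in the test loop prepends "summarize: " only when the architecture is t5 and passes the raw article through for every other architecture. A small stand-alone sketch of chaining a column reader with such a prefix step follows. It is plain Python; `compose`, `col_reader`, and `make_prefixer` are hypothetical stand-ins for fastai's Pipeline and ColReader, not their implementations.

```python
from functools import reduce
from typing import Callable, Dict


def compose(*funcs: Callable) -> Callable:
    """Apply the given functions left to right."""
    return lambda x: reduce(lambda acc, f: f(acc), funcs, x)


def col_reader(col: str) -> Callable[[Dict[str, str]], str]:
    """Read one named field from a row."""
    return lambda row: row[col]


def make_prefixer(arch: str) -> Callable[[str], str]:
    """Prepend the summarization task prefix for T5 only."""
    return lambda text: f"summarize: {text}" if arch == "t5" else text


row = {"article": "Twitter rolls out filters.", "highlights": "Filters arrive."}
get_x_t5 = compose(col_reader("article"), make_prefixer("t5"))
get_x_bart = compose(col_reader("article"), make_prefixer("bart"))
print(get_x_t5(row))
print(get_x_bart(row))
```

Binding the architecture when the function is built, as `make_prefixer` does, avoids depending on a loop variable that changes between iterations.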