blurr
Modeling

The text.modeling.seq2seq.summarization module contains custom models, custom splitters, etc., for summarization tasks.

Mid-level API

Example

The objective of summarization is to generate a concise and accurate representation of a much larger body of text. For example, we may want to summarize an article in a single sentence.

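import pandas as pd
from datasets import load_dataset

# Load the first 1,000 training examples of CNN/DailyMail (version 3.0.0) into a DataFrame
dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0", split="train[:1000]")
cnndm_df = pd.DataFrame(dataset)
cnndm_df.head(2)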
0  article:    It's official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria. Obama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons. The proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction." It's a step that is set to turn an internat...
   highlights: Syrian official: Obama climbed to the top of the tree, "doesn't know how to get down"\nObama sends a letter to the heads of the House and Senate .\nObama to seek congressional approval on military action against Syria .\nAim is to determine whether CW were used, not by whom, says U.N. spokesman .
   id:         0001d1afc246a7964130f43ae940af6bc6c57f01
1  article:    (CNN) -- Usain Bolt rounded off the world championships Sunday by claiming his third gold in Moscow as he anchored Jamaica to victory in the men's 4x100m relay. The fastest man in the world charged clear of United States rival Justin Gatlin as the Jamaican quartet of Nesta Carter, Kemar Bailey-Cole, Nickel Ashmeade and Bolt won in 37.36 seconds. The U.S finished second in 37.56 seconds with Canada taking the bronze after Britain were disqualified for a faulty handover. The 26-year-old Bolt has now collected eight gold medals at world championships, equaling the record held by American trio...
   highlights: Usain Bolt wins third gold of world championship .\nAnchors Jamaica to 4x100m relay victory .\nEighth gold at the championships for Bolt .\nJamaica double up in women's 4x100m relay .
   id:         0002095e55fcbd3a2f366d9bf92a95433dc305ef
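from transformers import BartForConditionalGeneration

pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
# get_hf_objects returns the architecture name, config, tokenizer, and model for the checkpoint
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)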
loading configuration file https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/config.json from cache at /home/wgilliam/.cache/huggingface/transformers/98e51ece807bb08f235356791c26c1d775cc56c394304f0ddf1809c6bc45b391.a394a5757192281a4f3940a7ccf20051a750f630dd86fffbaa84d8cff7a0d496
Model config BartConfig {
  "_name_or_path": "sshleifer/distilbart-cnn-6-6",
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "extra_pos_embeddings": 2,
  "force_bos_token_to_be_generated": true,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "length_penalty": 2.0,
  "max_length": 142,
  "max_position_embeddings": 1024,
  "min_length": 56,
  "model_type": "bart",
  "no_repeat_ngram_size": 3,
  "normalize_before": false,
  "normalize_embedding": true,
  "num_beams": 4,
  "num_hidden_layers": 6,
  "output_past": true,
  "pad_token_id": 1,
  "prefix": " ",
  "replacing_rate": 0,
  "scale_embedding": false,
  "static_position_embeddings": false,
  "student_decoder_layers": null,
  "student_encoder_layers": null,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "no_repeat_ngram_size": 3,
      "num_beams": 4
    }
  },
  "transformers_version": "4.18.0",
  "use_cache": true,
  "vocab_size": 50264
}

('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
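The checkpoint's task_specific_params["summarization"] includes no_repeat_ngram_size: 3. During Hugging Face generation, this setting blocks any next token that would recreate an n-gram of that size already present in the output. The sketch below illustrates the idea in plain Python; it works on word strings rather than token ids and is a simplified illustration, not the transformers implementation.

```python
def banned_next_tokens(generated, ngram_size=3):
    """Tokens that would complete an n-gram already present in `generated`."""
    # Too short for the next token to complete a full n-gram.
    if len(generated) + 1 < ngram_size:
        return set()
    # The last (ngram_size - 1) tokens: the prefix the next token would extend.
    prefix = tuple(generated[len(generated) - ngram_size + 1:])
    banned = set()
    # Scan every earlier n-gram; if its first n-1 tokens match the prefix,
    # its final token is not allowed next.
    for i in range(len(generated) - ngram_size + 1):
        if tuple(generated[i:i + ngram_size - 1]) == prefix:
            banned.add(generated[i + ngram_size - 1])
    return banned


words = "the men split into two groups and the men".split()
print(banned_next_tokens(words, ngram_size=3))  # {'split'}
```

Here "the men split" already occurred, so after the second "the men" the word "split" is blocked.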
import inspect

text_gen_kwargs = {}
if hf_arch in ["bart", "t5"]:
    # Start from the checkpoint's summarization defaults, overriding the max/min generation lengths
    text_gen_kwargs = {**hf_config.task_specific_params["summarization"], **{"max_length": 30, "min_length": 10}}

# Not all "summarization" parameters are accepted by model.generate, so remove the ones it doesn't take
generate_func_args = list(inspect.signature(hf_model.generate).parameters.keys())
for k in text_gen_kwargs.copy():
    if k not in generate_func_args:
        del text_gen_kwargs[k]

if hf_arch == "mbart":
    text_gen_kwargs["decoder_start_token_id"] = hf_tokenizer.get_vocab()["en_XX"]
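The merge-then-filter step can be sketched in isolation with a stand-in for model.generate. The parameter values are copied from this checkpoint's summarization settings; generate_stub and hypothetical_extra_key are illustrative assumptions, not part of transformers or blurr.

```python
import inspect

# Values copied from this checkpoint's task_specific_params["summarization"],
# plus one hypothetical key that the stub below does not accept.
summarization_params = {
    "early_stopping": True,
    "length_penalty": 2.0,
    "max_length": 142,
    "min_length": 56,
    "no_repeat_ngram_size": 3,
    "num_beams": 4,
    "hypothetical_extra_key": "ignored",
}


def generate_stub(input_ids=None, max_length=None, min_length=None, early_stopping=None,
                  num_beams=None, length_penalty=None, no_repeat_ngram_size=None):
    """Hypothetical stand-in for model.generate; only its signature is used."""
    return None


def build_generate_kwargs(task_params, overrides, generate_fn):
    # In {**a, **b}, keys from b win, so the overrides replace the defaults.
    merged = {**task_params, **overrides}
    accepted = set(inspect.signature(generate_fn).parameters)
    # Keep only the keyword arguments that generate_fn declares.
    return {k: v for k, v in merged.items() if k in accepted}


sketch_kwargs = build_generate_kwargs(
    summarization_params, {"max_length": 30, "min_length": 10}, generate_stub
)
print(sketch_kwargs)
```

Building a new dict with a comprehension avoids deleting keys while iterating, which is why the original cell iterates over a .copy().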
tok_kwargs = {}
if hf_arch == "mbart":
    tok_kwargs["src_lang"], tok_kwargs["tgt_lang"] = "en_XX", "en_XX"
batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch,
    hf_config,
    hf_tokenizer,
    hf_model,
    max_length=256,
    max_target_length=130,
    tok_kwargs=tok_kwargs,
    text_gen_kwargs=text_gen_kwargs,
)

blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)

dblock = DataBlock(blocks=blocks, get_x=ColReader("article"), get_y=ColReader("highlights"), splitter=RandomSplitter())
dls = dblock.dataloaders(cnndm_df, bs=2)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 75]))
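These shapes are consistent with inputs being truncated to max_length=256 and targets being padded only to the longest target in each batch (75 here, 59 in the batch used for the forward pass in the Training section), well below max_target_length=130. The sketch below shows that truncate-then-pad-to-longest pattern in plain Python. It assumes a pad id of 1 (BART's pad_token_id in the config) and uses toy token ids; it is an illustration, not blurr's or the tokenizer's code.

```python
def truncate_and_pad(batch, max_len, pad_id=1):
    """Truncate each sequence to max_len, then pad every sequence to the longest one."""
    # Note: a real tokenizer keeps special tokens such as </s> when truncating;
    # this simplified version just cuts the list.
    truncated = [seq[:max_len] for seq in batch]
    longest = max(len(seq) for seq in truncated)
    padded = [seq + [pad_id] * (longest - len(seq)) for seq in truncated]
    # Attention mask: 1 marks a real token, 0 marks padding.
    mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in truncated]
    return padded, mask


# Toy target token ids (hypothetical): two summaries of different lengths.
targets = [[0, 11, 12, 13, 2], [0, 21, 2]]
padded, mask = truncate_and_pad(targets, max_len=130, pad_id=1)
print(padded)  # [[0, 11, 12, 13, 2], [0, 21, 2, 1, 1]]
print(mask)    # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```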
dls.show_batch(dataloaders=dls, max_n=2)
0  text:   <s> (CNN) -- When Ji Yeqing awakened, she was already in the recovery room. Chinese authorities had dragged her out of her home and down four flights of stairs, she said, restraining and beating her husband as he tried to come to her aid. They whisked her into a clinic, held her down on a bed and forced her to undergo an abortion. Her offense? Becoming pregnant with a second child, in violation of China's one-child policy. "After the abortion, I felt empty, as if something was scooped out of me," Ji told a congressional panel in September. "My husband and I had been so excited for our new baby. Now suddenly all that hope and joy and excitement disappeared.... I was very depressed and despondent. For a long time, whenever I thought about my lost child, I would cry." As she lay unconscious, she said, an IUD to prevent future pregnancies was inserted. The issue of forced abortions -- and in some cases, forced sterilizations -- in China has seized the spotlight in recent days with news of escaped activist Chen Guangcheng. Chen, a blind, self-taught lawyer, rose to fame in the late 1990s because of his advocacy for what he calls victims</s>
   target: China's one-child policy results in forced abortions and sterilizations, activists say.\nWomen tell of emotional and physical consequences from the procedures.\nActivist Chen Guangcheng works to advocate for victims of such practices.
1  text:   <s> (CNN) -- The generation of gays and lesbians that literally created the modern LGBT movement -- from the heroes of the 1969 Stonewall riots to their slightly younger friends -- is at, or nearing, retirement age. That used to mean the beginning of an extremely difficult time in an LGBT person's life. But as gay baby boomers find more acceptance in mainstream society and continue to do what they've always done -- push to make a better world for the LGBT community -- their retirement options are slowly improving. That is, if they decide to retire at all. "The notion of retirement has never been a part of my vocabulary," said Bob Witeck, CEO and co-founder of Witeck Communications. Nearly 61, Witeck has put some thought into what he should do with his strategic public relations and marketing firm as he gets older. Like many friends his age who are also entrepreneurs, he plans to keep working. "Because I run a business, as I get older I can change the intensity of my engagement in the kinds of work I take on," Witeck said. "I know I'm lucky that way, and I'm lucky in my personal life as well. My husband is 50, so I have a younger man to help me</s>
   target: LGBT baby boomers changed the visibility of the gay community.\nAs they approach retirement, they face different obstacles than their straight counterparts.\nWithout marriage equality, same-sex couples may face financial hardships.\nAdvocates say the situation is slowly improving.
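The DataBlock uses fastai's RandomSplitter, whose default valid_pct is 0.2, so the 1,000 loaded rows should split into about 800 training and 200 validation items. A minimal sketch of that index split, written with Python's random module instead of fastai's torch-based shuffle:

```python
import random


def random_split(n_items, valid_pct=0.2, seed=None):
    """Shuffle item indices, then cut off the first valid_pct of them for validation."""
    rng = random.Random(seed)
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]  # (train indices, valid indices)


train_idx, valid_idx = random_split(1000, seed=42)
print(len(train_idx), len(valid_idx))  # 800 200
```

Passing a seed makes the split reproducible across runs, which matters when comparing validation metrics between experiments.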

Training

seq2seq_metrics = {
    "rouge": {
        "compute_kwargs": {"rouge_types": ["rouge1", "rouge2", "rougeL", "rougeLsum"], "use_stemmer": True},
        "returns": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
    },
    "bertscore": {"compute_kwargs": {"lang": "en"}, "returns": ["precision", "recall", "f1"]},
}
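from functools import partial

from fastai.learner import Learner
from fastai.losses import CrossEntropyLossFlat
from fastai.optimizer import Adam

model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
# Passed to fit so that the ROUGE and BERTScore metrics defined above are computed
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(
    dls,
    model,
    opt_func=partial(Adam),
    loss_func=CrossEntropyLossFlat(),
    cbs=learn_cbs,
    splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
)

# learn = learn.to_native_fp16() #.to_fp16()
learn.freeze()
learn.summary()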
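The rouge entry in seq2seq_metrics requests ROUGE-1, ROUGE-2, ROUGE-L and ROUGE-Lsum scores. As a rough sketch of what those numbers measure, ROUGE-N compares n-gram overlap between a generated and a reference summary, and ROUGE-L uses their longest common subsequence. This simplified version lowercases and splits on whitespace; the actual metric (backed by the rouge_score package) uses its own tokenizer, applies Porter stemming when use_stemmer=True, and for ROUGE-Lsum treats newlines as sentence boundaries, which matters because the CNN/DailyMail highlights are newline-separated.

```python
from collections import Counter


def rouge_n_f1(candidate, reference, n=1):
    """ROUGE-N F1 from clipped n-gram overlap (lowercased whitespace tokens, no stemming)."""
    def ngrams(tokens):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand = ngrams(candidate.lower().split())
    ref = ngrams(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared n-gram.
    overlap = sum((cand & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if x == y else max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]


def rouge_l_f1(candidate, reference):
    """ROUGE-L F1 based on the longest common subsequence of tokens."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


reference = "Usain Bolt wins third gold of world championship"
candidate = "Bolt wins third gold"
print(rouge_n_f1(candidate, reference, n=1), rouge_n_f1(candidate, reference, n=2), rouge_l_f1(candidate, reference))
```

Every candidate word appears in the reference (precision 1.0) but the candidate covers only half of it (recall 0.5), giving a ROUGE-1 F1 of 2/3.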
b = dls.one_batch()
preds = learn.model(b[0])

len(preds), preds["loss"].shape, preds["logits"].shape
(3, torch.Size([]), torch.Size([2, 59, 50264]))
len(b), len(b[0]), b[0]["input_ids"].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 256]), 2, torch.Size([2, 59]))
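The logits come back as [batch, target_length, vocab_size] = [2, 59, 50264] and the targets as [2, 59], while the "loss" entry is computed by the Hugging Face model itself from the labels. fastai's CrossEntropyLossFlat flattens logits to [batch * target_length, vocab_size] and targets to [batch * target_length] before applying ordinary cross-entropy, and PyTorch's cross-entropy ignores targets equal to -100 by default. The pure-Python sketch below mirrors that computation on toy numbers; whether blurr marks padded targets with -100 is an assumption, so ignore_index is left as a parameter.

```python
import math


def flat_cross_entropy(logits, targets, ignore_index=-100):
    """Mean cross-entropy over nested [batch][seq][vocab] logits and [batch][seq] targets."""
    # Flatten [batch, seq, vocab] -> [batch*seq, vocab] and [batch, seq] -> [batch*seq].
    flat_logits = [row for seq in logits for row in seq]
    flat_targets = [t for seq in targets for t in seq]
    losses = []
    for row, target in zip(flat_logits, flat_targets):
        if target == ignore_index:
            continue  # ignored positions contribute nothing
        # Numerically stable log-sum-exp, then negative log-probability of the target.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[target])
    return sum(losses) / len(losses)


# Toy example: batch of 1, 2 target positions, vocabulary of 3; the second position is ignored.
logits = [[[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
targets = [[0, -100]]
print(flat_cross_entropy(logits, targets))
```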
print(len(learn.opt.param_groups))
3
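The optimizer has three parameter groups, as produced by blurr_seq2seq_splitter. In fastai, learn.freeze() is freeze_to(-1): every group except the last is frozen, so only the final group trains until learn.unfreeze() (freeze_to(0)) is called. A sketch of that bookkeeping, ignoring fastai's train_bn exception for batch-norm layers and making no claim about which layers blurr places in each group:

```python
def trainable_groups(n_groups, n):
    """Which parameter groups stay trainable after fastai-style freeze_to(n)."""
    # A negative n counts from the end, as in Python indexing.
    frozen_idx = n if n >= 0 else n_groups + n
    return [i >= frozen_idx for i in range(n_groups)]


print(trainable_groups(3, -1))  # freeze():   [False, False, True]
print(trainable_groups(3, 0))   # unfreeze(): [True, True, True]
```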
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=4.786300996784121e-05, steep=6.309573450380412e-07, valley=6.30957365501672e-05, slide=1.4454397387453355e-05)

learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch train_loss valid_loss rouge1 rouge2 rougeL rougeLsum bertscore_precision bertscore_recall bertscore_f1 time
0 2.306568 2.146496 0.296782 0.124385 0.222347 0.275367 0.892347 0.865695 0.878717 01:12
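fit_one_cycle(1, lr_max=4e-5) follows fastai's one-cycle policy. With the defaults div=25, div_final=1e5 and pct_start=0.25, the learning rate rises along a cosine curve from lr_max/25 to lr_max over the first 25% of iterations, then anneals along a second cosine curve down to lr_max/1e5. fastai also schedules momentum in the opposite direction; this sketch covers only the learning rate.

```python
import math


def sched_cos(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos (0 = first batch, 1 = last)."""
    if pos < pct_start:
        # Warm-up phase: lr_max / div -> lr_max.
        return sched_cos(lr_max / div, lr_max, pos / pct_start)
    # Annealing phase: lr_max -> lr_max / div_final.
    return sched_cos(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


for pos in (0.0, 0.25, 0.5, 1.0):
    print(pos, one_cycle_lr(pos, lr_max=4e-5))
```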

Showing results

And here we create a @typedispatch implementation of Learner.show_results that displays each input text alongside its target summary and the model's predicted summary.

learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
0  text:       (CNN Student News) -- January 13, 2011. Download PDF maps related to today's show:. • Arizona • Australia. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. CARL AZUZ, CNN STUDENT NEWS ANCHOR: A problem that won't be solved, even if the solution is clear. The story and the reasons, leading off today's broadcast of CNN Student News! My name is Carl Azuz! First Up: Winter Storm Woes. AZUZ: Florida is the only state in the union without snow on the g
   target:     A winter storm slams the northeastern United States.\nThe U.S. House of Representatives condemns the Arizona shooting.\nMassive floods leave vast areas of Australia underwater.\nUse the Daily Discussion to help students understand today's featured news
   prediction: [ Find out how a storm system iced out the southeast . Use the Daily Discussion to help students understand today's featured news stories, Jeb Bush and Mitt Romney are putting pressure on New Jersey Gov. Chris Christie .\nBush has been a well-liked figure]
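fastcore's @typedispatch picks an implementation from the types of a function's leading arguments, which is how a summarization-specific show_results can coexist with the versions for other tasks. Python's standard functools.singledispatch, which dispatches on the first argument only, is used below as a simpler stand-in; the marker classes are hypothetical and not blurr types.

```python
from functools import singledispatch


class PlainText(str):
    """Hypothetical marker type for an input with no specialised display."""


class SummarizationText(str):
    """Hypothetical marker type for summarization inputs."""


@singledispatch
def show_result(x, target, prediction):
    # Generic fallback: show only input and target.
    return f"{x} | {target}"


@show_result.register
def _(x: SummarizationText, target, prediction):
    # Summarization-specific view: also show the generated summary.
    return f"{x} | {target} | {prediction}"


print(show_result(PlainText("article"), "reference", "generated"))
print(show_result(SummarizationText("article"), "reference", "generated"))
```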

Prediction

Here we add a Learner.blurr_summarize method that brings the results in line with the format returned by Hugging Face's pipeline method.
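The outputs in this section come back as one dict per input: a list of strings under summary_texts when several sequences are requested, and a single string when only one is (as in the exported learner's output in the Inference section). A sketch of that grouping, assuming the generated texts arrive as a flat list that is contiguous per input, which is how Hugging Face's generate orders num_return_sequences outputs. to_pipeline_format is a hypothetical helper, not blurr's code; note also that Hugging Face's own summarization pipeline uses the key summary_text.

```python
def to_pipeline_format(generated, n_inputs, num_return_sequences=1, key="summary_texts"):
    """Group a flat list of generated texts into one dict per input."""
    results = []
    for i in range(n_inputs):
        # Outputs for input i occupy a contiguous slice of the flat list.
        texts = generated[i * num_return_sequences:(i + 1) * num_return_sequences]
        # One sequence -> bare string; several -> list of strings.
        results.append({key: texts[0] if num_return_sequences == 1 else texts})
    return results


print(to_pipeline_format(["a", "b", "c"], n_inputs=1, num_return_sequences=3))
print(to_pipeline_format(["x", "y"], n_inputs=2))
```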

test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French lRicense plates. CNN's Andreena Narayan 
contributed to this report.
"""
outputs = learn.blurr_generate(test_article, key="summary_texts", num_return_sequences=3)
outputs
[{'summary_texts': [" 10 men armed with pistols and small machine guns raided a casino in Switzerland .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the casino's vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the armed robbers' vehicles .",
   " 10 men armed with pistols and small machine guns raided a casino in Switzerland .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the casino's vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the robbers' vehicles .",
   " 10 men armed with pistols and small machine guns raided a casino in Switzerland .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino .\nOne group tried to break into the casino's vault on the lower level but could not get in ."]}]


Learner.blurr_summarize

 Learner.blurr_summarize (inp, **kwargs)
learn.blurr_summarize(test_article, num_return_sequences=3)
[{'summary_texts': [" 10 men armed with pistols and small machine guns raided a casino in Switzerland .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the casino's vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the armed robbers' vehicles .",
   " 10 men armed with pistols and small machine guns raided a casino in Switzerland .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the casino's vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the robbers' vehicles .",
   " 10 men armed with pistols and small machine guns raided a casino in Switzerland .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino .\nOne group tried to break into the casino's vault on the lower level but could not get in ."]}]

Inference

Using fast.ai Learner.export and load_learner

export_fname = "summarize_export"
learn.metrics = None
learn.export(fname=f"{export_fname}.pkl")
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_summarize(test_article)
[{'summary_texts': ' 10 men armed with pistols and small machine guns raided a casino in Switzerland .\nThe men, dressed in black clothes and black ski'}]
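The three alternative summaries returned earlier with num_return_sequences=3 come out of beam search (num_beams=4 in the checkpoint's summarization settings). In transformers, each finished hypothesis is scored roughly as its summed token log-probability divided by length ** length_penalty, and num_return_sequences may not exceed num_beams. A simplified sketch of that final ranking step, using hypothetical hypotheses and toy log-probabilities:

```python
def rank_hypotheses(hypotheses, length_penalty=2.0, num_beams=4, num_return_sequences=3):
    """Return the best finished hypotheses by length-normalised log-probability."""
    if num_return_sequences > num_beams:
        raise ValueError("num_return_sequences must not exceed num_beams")
    # Score = summed token log-probs / (length ** length_penalty).
    scored = [
        (sum(logprobs) / (len(logprobs) ** length_penalty), text)
        for text, logprobs in hypotheses
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [text for _, text in scored[:num_return_sequences]]


# Hypothetical finished beams: (text, per-token log-probabilities).
hypotheses = [
    ("short", [-1.0, -1.0]),
    ("medium", [-1.0] * 4),
    ("long", [-1.0] * 6),
    ("unlikely", [-3.0] * 3),
]
print(rank_hypotheses(hypotheses))                      # ['long', 'medium', 'short']
print(rank_hypotheses(hypotheses, length_penalty=0.0))  # ['short', 'medium', 'long']
```

Because log-probabilities are negative, a positive length_penalty such as this checkpoint's 2.0 divides longer hypotheses by a larger number, pulling their scores toward zero and favoring longer summaries.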

High-level API



BlearnerForSummarization

 BlearnerForSummarization (dls:fastai.data.core.DataLoaders,
                           hf_model:transformers.modeling_utils.PreTrainedModel,
                           base_model_cb:blurr.text.modeling.core.BaseModelCallback=<class 'blurr.text.modeling.core.BaseModelCallback'>,
                           loss_func:callable|None=None,
                           opt_func=<function Adam>, lr=0.001,
                           splitter:callable=<function trainable_params>,
                           cbs=None, metrics=None, path=None,
                           model_dir='models', wd=None, wd_bn_bias=False,
                           train_bn=True, moms=(0.95, 0.85, 0.95),
                           default_cbs:bool=True)

Group together a model, some dls and a loss_func to handle training

Parameter   Type              Details
dls         DataLoaders       DataLoaders containing data for each dataset needed for model
hf_model    PreTrainedModel

Example

learn = BlearnerForSummarization.from_data(
    cnndm_df,
    "sshleifer/distilbart-cnn-6-6",
    text_attr="article",
    summary_attr="highlights",
    max_length=256,
    max_target_length=130,
    dblock_splitter=RandomSplitter(),
    dl_kwargs={"bs": 2},
).to_fp16()
If your task is similar to the task the model of the checkpoint was trained on, you can already use BartForConditionalGeneration for predictions without further training.
loading file https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/vocab.json from cache at /home/wgilliam/.cache/huggingface/transformers/c457182dd3c47e71636dfe957c948acf12fd6b1d17d3e16a69f9bd731f340157.647b4548b6d9ea817e82e7a9231a320231a1c9ea24053cc9e758f3fe68216f05
loading file https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/merges.txt from cache at /home/wgilliam/.cache/huggingface/transformers/1917cd1903f32920951797d984eff6fb9707c20aa7c0eba679d033d5d5dbc7d3.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
loading file https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/tokenizer.json from cache at None
loading file https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/tokenizer_config.json from cache at /home/wgilliam/.cache/huggingface/transformers/41a44e7ad55ba42aa9abd4697be8ff844b95c3f33ad59ceb5059b263caf581fe.67d01b18f2079bd75eac0b2f2e7235768c7f26bd728e7a855a1c5acae01a91a8
learn.fit_one_cycle(1, lr_max=4e-5, cbs=[BlearnerForSummarization.get_metrics_cb()])
epoch train_loss valid_loss rouge1 rouge2 rougeL rougeLsum bertscore_precision bertscore_recall bertscore_f1 time
0 2.217999 2.165818 0.363228 0.142654 0.249067 0.338458 0.879375 0.888943 0.884054 02:44
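The `rouge1`, `rouge2` and `rougeL` columns above are F1-style overlap scores between each generated summary and its reference:

- ROUGE-N counts shared n-grams.
- ROUGE-L uses the longest common subsequence of tokens.

The minimal sketch below shows the idea in pure Python. It is a simplification (lowercased whitespace tokenization, no stemming, no `rougeLsum`) and not the implementation blurr's metrics callback actually calls. The function names are illustrative.

```python
from collections import Counter
from typing import List

def ngrams(tokens: List[str], n: int) -> Counter:
    """Multiset of the n-grams in `tokens`."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def f1(overlap: int, pred_total: int, ref_total: int) -> float:
    if overlap == 0 or pred_total == 0 or ref_total == 0:
        return 0.0
    precision, recall = overlap / pred_total, overlap / ref_total
    return 2 * precision * recall / (precision + recall)

def rouge_n(prediction: str, reference: str, n: int = 1) -> float:
    pred = ngrams(prediction.lower().split(), n)
    ref = ngrams(reference.lower().split(), n)
    overlap = sum((pred & ref).values())  # clipped counts of shared n-grams
    return f1(overlap, sum(pred.values()), sum(ref.values()))

def lcs_length(a: List[str], b: List[str]) -> int:
    """Longest common subsequence length, dynamic programming one row at a time."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l(prediction: str, reference: str) -> float:
    pred, ref = prediction.lower().split(), reference.lower().split()
    return f1(lcs_length(pred, ref), len(pred), len(ref))

pred = "the cat sat on the mat"
ref = "the cat is on the mat"
print(rouge_n(pred, ref, 1), rouge_n(pred, ref, 2), rouge_l(pred, ref))
```

For this pair the sketch prints roughly 0.83, 0.6 and 0.83. A single changed word costs more at the bigram level because it breaks two bigrams at once.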
Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/roberta-large/resolve/main/config.json from cache at /home/wgilliam/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373
Model config RobertaConfig {
  "_name_or_path": "roberta-large",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.18.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

loading file https://huggingface.co/roberta-large/resolve/main/vocab.json from cache at /home/wgilliam/.cache/huggingface/transformers/7c1ba2435b05451bc3b4da073c8dec9630b22024a65f6c41053caccf2880eb8f.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab
loading file https://huggingface.co/roberta-large/resolve/main/merges.txt from cache at /home/wgilliam/.cache/huggingface/transformers/20b5a00a80e27ae9accbe25672aba42ad2d4d4cb2c4b9359b50ca8e34e107d6d.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
loading file https://huggingface.co/roberta-large/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/roberta-large/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/roberta-large/resolve/main/tokenizer_config.json from cache at None
loading weights file https://huggingface.co/roberta-large/resolve/main/pytorch_model.bin from cache at /home/wgilliam/.cache/huggingface/transformers/8e36ec2f5052bec1e79e139b84c2c3089cb647694ba0f4f634fec7b8258f7c89.c43841d8c5cd23c435408295164cda9525270aa42cd0cc9200911570c0342352
All the weights of RobertaModel were initialized from the model checkpoint at roberta-large.
If your task is similar to the task the model of the checkpoint was trained on, you can already use RobertaModel for predictions without further training.
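The `bertscore_*` columns, and the `roberta-large` checkpoint loaded right after training, most likely come from BERTScore. The `bert_score` package defaults to `roberta-large` for English.

Conceptually, BERTScore works in two steps:

1. It embeds every token of the prediction and of the reference with such a model.
2. It greedily matches tokens by cosine similarity.

The minimal sketch below reproduces only the matching step, on hand-written toy vectors that stand in for RoBERTa embeddings. It omits the package's optional IDF weighting and baseline rescaling, and the function names are illustrative.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]

def normalize(v: Vector) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))

def greedy_bertscore(candidate: Sequence[Vector],
                     reference: Sequence[Vector]) -> Tuple[float, float, float]:
    """Precision, recall, F1 from greedy cosine matching (no IDF weighting, no rescaling)."""
    cand = [normalize(v) for v in candidate]
    ref = [normalize(v) for v in reference]
    # precision: every candidate token takes its best match among the reference tokens
    precision = sum(max(dot(c, r) for r in ref) for c in cand) / len(cand)
    # recall: every reference token takes its best match among the candidate tokens
    recall = sum(max(dot(r, c) for c in cand) for r in ref) / len(ref)
    f1 = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# toy 2-d "embeddings": one candidate token matches the reference, the other does not
print(greedy_bertscore([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0]]))
```

In this toy case, precision is 0.5 because the extra candidate token has no close match. Recall is 1.0, so F1 is about 0.67. Because matching uses embeddings rather than exact strings, paraphrases can score well here even when their ROUGE overlap is low.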
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 (CNN) -- When Ji Yeqing awakened, she was already in the recovery room. Chinese authorities had dragged her out of her home and down four flights of stairs, she said, restraining and beating her husband as he tried to come to her aid. They whisked her into a clinic, held her down on a bed and forced her to undergo an abortion. Her offense? Becoming pregnant with a second child, in violation of China's one-child policy. "After the abortion, I felt empty, as if something was scooped out of me," J China's one-child policy results in forced abortions and sterilizations, activists say.\nWomen tell of emotional and physical consequences from the procedures.\nActivist Chen Guangcheng works to advocate for victims of such practices. [ Ji Yeqing says she was forced to have an abortion in violation of China's one-child policy .\nShe says she felt "empty" after the abortion .\nThe issue of forced abortions in China has seized the spotlight in recent days .\nIn some cases, forced sterilizations are used to prevent future pregnancies ., Malala Yousufzai was shot in the neck by Taliban militants on Tuesday .\nMalala is recovering after surgeons worked for three hours to remove a bullet lodged in her neck .\nAn angry chorus of voices in social media, on the street, and over the airwaves decries the attack .\nThe 14-year-old is a defiant blogger .]
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan 
contributed to this report.
"""
learn.predict(test_article, num_return_sequences=3)
[{'summary_texts': [" 10 men raid Swiss casino in early hours of Sunday morning, police say .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the armed robbers' vehicles .",
   " 10 men raid Swiss casino in early hours of Sunday morning, police say .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the robbers' vehicles .",
   " 10 men raid Swiss casino in early hours of Sunday morning, police say .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the robbers' vehicles .\n"]}]
export_fname = "summarize_export"

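learn.metrics = None     # drop the metrics before exporting
learn = learn.to_fp32()  # switch back from the mixed precision enabled earlier with .to_fp16()
learn.export(fname=f"{export_fname}.pkl")

inf_learn = load_learner(fname=f"{export_fname}.pkl")  # reload the exported Learner for inference
inf_learn.blurr_summarize(test_article)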
[{'summary_texts': " 10 men raid Swiss casino in early hours of Sunday morning, police say .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid .\nOne group tried to break into the vault on the lower level but could not get in .\nA woman driving by and unaware of what was happening unknowingly blocked the armed robbers' vehicles ."}]
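The three summaries returned earlier by `learn.predict(test_article, num_return_sequences=3)` are nearly identical. That is what you would expect if they are the three best hypotheses from a single beam search. The checkpoint's config above sets `num_beams=4`, `length_penalty=2.0` and `no_repeat_ngram_size=3`.

As we understand Hugging Face's beam search, two rules apply:

- A finished hypothesis is scored as its summed token log-probabilities divided by `length ** length_penalty`.
- `no_repeat_ngram_size=n` forbids any next token that would recreate an n-gram already in the output.

The sketch below illustrates both rules on plain Python lists. It is not the `transformers` implementation, and `banned_next_tokens`, `beam_score` and `top_k` are illustrative names.

```python
from typing import List, Sequence, Set

def banned_next_tokens(generated: Sequence[str], n: int) -> Set[str]:
    """Tokens that would complete an n-gram already present in `generated`."""
    if len(generated) < n - 1:
        return set()
    prefix = tuple(generated[len(generated) - (n - 1):])  # last n-1 tokens
    banned = set()
    for i in range(len(generated) - n + 1):
        if tuple(generated[i:i + n - 1]) == prefix:
            banned.add(generated[i + n - 1])
    return banned

def beam_score(token_logprobs: Sequence[float], length_penalty: float) -> float:
    """Length-normalized score of a finished hypothesis."""
    return sum(token_logprobs) / (len(token_logprobs) ** length_penalty)

def top_k(hypotheses: List[Sequence[float]], k: int, length_penalty: float) -> List[Sequence[float]]:
    """The k best hypotheses, as num_return_sequences=k would keep them."""
    return sorted(hypotheses, key=lambda h: beam_score(h, length_penalty), reverse=True)[:k]

print(banned_next_tokens(["the", "men", "raid", "the", "men"], 3))  # {'raid'}

short_hyp, long_hyp = [-1.0, -1.0], [-1.0, -1.0, -1.0, -1.0]
print(top_k([short_hyp, long_hyp], 1, length_penalty=2.0) == [long_hyp])   # True
print(top_k([short_hyp, long_hyp], 1, length_penalty=0.0) == [short_hyp])  # True
```

The log-probability sum is negative. With a positive `length_penalty`, longer outputs get a larger denominator, so their scores are pulled toward zero and beam search leans toward longer summaries.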

Tests

The purpose of the following tests is to ensure, as much as possible, that the core training code works for the pretrained summarization models below. These tests are excluded from the CI workflow because they would take a long time to run and require downloading a large amount of data.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with. If any of them fail, please submit a GitHub issue (or a PR if you’d like to fix it yourself).

arch tokenizer model_name result error
0 bart BartTokenizerFast BartForConditionalGeneration PASSED
1 led LEDTokenizerFast LEDForConditionalGeneration PASSED
2 mbart MBartTokenizerFast MBartForConditionalGeneration PASSED
3 mt5 T5TokenizerFast MT5ForConditionalGeneration PASSED
4 pegasus PegasusTokenizerFast PegasusForConditionalGeneration PASSED
5 t5 T5TokenizerFast T5ForConditionalGeneration PASSED
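The notebook's own test code is not reproduced on this page. The table above reflects a simple pattern: one row per architecture, recording PASSED or FAILED plus the error text. As a hypothetical sketch of that pattern, the harness below runs a caller-supplied smoke test for each entry and keeps going when one fails.

`run_model_tests`, `format_results`, `fake_smoke_test` and the `toy` entry are all made up for illustration. The stand-in test fails on the `toy` entry on purpose and says nothing about the real results above.

```python
from typing import Callable, Dict, List

Entry = Dict[str, str]
COLUMNS = ["arch", "tokenizer", "model_name", "result", "error"]

def run_model_tests(entries: List[Entry], smoke_test: Callable[[Entry], None]) -> List[Entry]:
    """Run `smoke_test` on each entry and record PASSED/FAILED plus the error text."""
    results = []
    for entry in entries:
        row = dict(entry)
        try:
            smoke_test(entry)
            row["result"], row["error"] = "PASSED", ""
        except Exception as err:  # keep going so one broken model does not hide the others
            row["result"], row["error"] = "FAILED", str(err)
        results.append(row)
    return results

def format_results(rows: List[Entry]) -> str:
    """Render rows in the same space-separated layout as the table above."""
    lines = [" ".join(COLUMNS)]
    for i, row in enumerate(rows):
        lines.append(" ".join([str(i)] + [row[c] for c in COLUMNS]).rstrip())
    return "\n".join(lines)

# Hypothetical stand-in: a real smoke test might build DataLoaders, create a Learner for the
# checkpoint and fit briefly; this one only fails on the made-up "toy" entry.
def fake_smoke_test(entry: Entry) -> None:
    if entry["arch"] == "toy":
        raise RuntimeError("simulated failure")

ENTRIES = [
    {"arch": "bart", "tokenizer": "BartTokenizerFast", "model_name": "BartForConditionalGeneration"},
    {"arch": "toy", "tokenizer": "ToyTokenizerFast", "model_name": "ToyForConditionalGeneration"},
]

print(format_results(run_model_tests(ENTRIES, fake_smoke_test)))
```

Catching the exception per entry, rather than letting it propagate, is what lets a single run report every architecture's status at once.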