This module contains the core custom models, loss functions, etc., for Seq2Seq tasks (e.g., language modeling, summarization, translation).
 
What we're running with at the time this documentation was generated:
torch: 1.10.0+cu102
fastai: 2.5.3
transformers: 4.12.5

Seq2Seq

path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')

cnndm_df.head(2)
article highlights ds_type
0 (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el... John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge" train
1 (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ... NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says . train
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                  max_length=256, max_target_length=130)

blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
dls = dblock.dataloaders(cnndm_df, bs=2)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 69]))
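Note the shapes: the inputs are padded/truncated to max_length=256, while the targets' second dimension (69 here) is simply the length of the longest target sequence in this batch, with shorter targets padded using the ignore token ID (-100). A minimal sketch of that dynamic-padding idea (a hypothetical helper, not blurr's actual code):

```python
def pad_to_longest(seqs, pad_id=-100):
    """Pad every sequence to the length of the longest one in the batch."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]

# Two toy target sequences of different lengths
batch_targets = [[0, 11, 12, 2], [0, 21, 2]]
print(pad_to_longest(batch_targets))
# → [[0, 11, 12, 2], [0, 21, 2, -100]]
```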
dls.show_batch(dataloaders=dls, max_n=2)
text target
0 <s> (CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds, 1,150 varieties of birds and 26,000 classifications of plants. Pronatura, a non-profit organization that works to promote conservation and sustainable development in Mexico, has selected six species which it says symbolize the problems faced by the</s> Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development and climate change is placing a big strain on its biodiversity.\nThe Golden Eagle is under threat in spite of being the country's national symbol.
1 <s> Washington (CNN)Almost immediately following the news of the first terrorist attacks that eventually killed 17 people across France, the global community united around a Twitter hashtag "Je suis Charlie" and just days later foreign leaders linked arms with their French counterparts to lead a historic million-person strong rally. Meanwhile, explosives strapped to a girl who appeared to be about 10-years-old detonated on Saturday, killing at least 20 people, in a country whose encounters with terrorism were also punctuated by a hashtag -- this time "#BringBackOurGirls" of Nigeria. Boko Haram militants killed as many as 2,000 people, mostly civilians,in a massacre that started the weekend before the terror attack on Charlie Hedbo in downtown Paris. Both the attacks in Nigeria and those in Paris are shocking and horrifying in their own respects, and yet one fomented an unprecedented international reaction -- a popular show of force that rivaled even the reaction to 9/11 -- while the response to the attacks in Nigeria paled in comparison. Here are a few of the reasons why:. Symbolism. The terrorist attack on the satirical publication Charlie Hebdo was not just violent, but highly symbolic. While the terrorists in Nigeria targeted innocent civilians in a strategic northern town in Nigeria and in a crowded marketplace</s> France and Nigeria experienced waves of terrorism during the first weeks of 2015.\nWhile the terror attacks in Paris sparked international unified outrage, reaction to Nigeria was more muted.\nSymbolism, politics and media all played a role in how France's response to terrorism was perceived.

Training

Here we create a Seq2Seq-specific subclass of HF_BaseModelCallback in order to include custom Seq2Seq metrics and to handle the pre-calculated loss during training.

seq2seq_metrics

  • {'rouge': { 'compute_kwargs': {'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True}, 'returns': ["rouge1", "rouge2", "rougeL"] }}
  • {'bertscore': { 'returns': ["precision", "recall", "f1"] }}
  • {'bleu': { 'returns': "bleu" }}
  • {'bleurt': { 'returns': "scores" }}
  • {'meteor': { 'returns': "meteor" }}
  • {'sacrebleu': { 'returns': "score" }}

class HF_Seq2SeqMetricsCallback[source]

HF_Seq2SeqMetricsCallback(custom_metrics:dict=None, calc_every:str='epoch', ignore_token_id=-100, text_gen_kwargs:dict={}, **kwargs) :: Callback

A callback that adds seq2seq metrics

Parameters:

  • custom_metrics : <class 'dict'>, optional

    A dictionary of seq2seq metrics we want to use. See below and the various task specific seq2seq docs for examples of how to configure this per task

  • calc_every : <class 'str'>, optional

    Calculating these metrics requires text generation, which is expensive. You can choose to calculate them every 'epoch', every 'other_epoch', or only on the 'last_epoch' (default: 'epoch')

  • ignore_token_id : <class 'int'>, optional

    The token ID that should be ignored when calculating the loss

  • text_gen_kwargs : <class 'dict'>, optional

    Any keyword arguments to pass to the `hf_model.generate` method

  • kwargs : <class 'inspect._empty'>
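Because padded target positions are marked with ignore_token_id, they must be dropped before decoding target IDs back to text (the same filtering appears in the inference example later in this doc). A minimal, hypothetical sketch of that cleanup:

```python
IGNORE_TOKEN_ID = -100  # matches the callback's default ignore_token_id

def strip_ignored(token_ids, ignore_id=IGNORE_TOKEN_ID):
    """Drop label positions marked with the ignore id before decoding."""
    return [t for t in token_ids if t != ignore_id]

labels = [0, 713, 16, 10, 2, -100, -100]
print(strip_ignored(labels))
# → [0, 713, 16, 10, 2]
```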

We add a custom parameter splitter to give us finer control when applying discriminative learning rates to Seq2Seq models.

seq2seq_splitter[source]

seq2seq_splitter(m:PreTrainedModel, arch:str)

Custom param splitter for summarization models

Parameters:

  • m : <class 'transformers.modeling_utils.PreTrainedModel'>

    A Hugging Face model

  • arch : <class 'str'>

    The name of the architecture you are working with (e.g., bart, fsmt, pegasus)
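The idea behind a splitter is to group the model's parameters so that each group can receive its own learning rate; for this BART model, blurr produces three parameter groups (see the len(learn.opt.param_groups) check below). Here is a toy illustration of grouping parameters by name; this is not blurr's actual implementation:

```python
def split_by_prefix(param_names):
    """Toy grouping of parameter names into three layer groups:
    shared embeddings, encoder, and decoder/head. Illustrative only."""
    embeds, encoder, decoder = [], [], []
    for name in param_names:
        if "shared" in name or name.startswith("model.embed"):
            embeds.append(name)
        elif ".encoder." in name:
            encoder.append(name)
        else:
            decoder.append(name)
    return [embeds, encoder, decoder]

names = [
    "model.shared.weight",
    "model.encoder.layers.0.self_attn.k_proj.weight",
    "model.decoder.layers.0.self_attn.k_proj.weight",
    "lm_head.weight",
]
for group in split_by_prefix(names):
    print(group)
```

With discriminative learning rates (e.g., the `lr_max=slice(...)` used below), fastai spreads the low-to-high rates across these groups, so the embeddings train slowest and the head fastest.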

seq2seq_metrics = {
    'rouge': {
        'compute_kwargs': {
            'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True
        }, 
        'returns': ["rouge1", "rouge2", "rougeL"] 
    },
    'bertscore': {
        'compute_kwargs': { 'lang': 'en' },
        'returns': ["precision", "recall", "f1"]
    }, 
    'bleu': { 'returns': "bleu" },
    'meteor': { 'returns': "meteor" },
    'sacrebleu': { 'returns': "score" }
}
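Each metric's 'returns' entry determines the metric columns reported during training (rouge1, bertscore_precision, bleu, and so on in the table below). As an illustrative reconstruction (not blurr's internal code), the column names can be derived from such a config like this:

```python
def metric_column_names(cfg):
    """Flatten a seq2seq metrics config into per-metric column names.

    String 'returns' collapse to the metric key (e.g. sacrebleu's 'score'
    becomes 'sacrebleu'); list 'returns' keep their names, prefixed with
    the metric key when they don't already include it."""
    cols = []
    for key, spec in cfg.items():
        returns = spec["returns"]
        if isinstance(returns, str):
            cols.append(key)
        else:
            cols.extend(n if n.startswith(key) else f"{key}_{n}" for n in returns)
    return cols

metrics_cfg = {
    "rouge": {"returns": ["rouge1", "rouge2", "rougeL"]},
    "bertscore": {"returns": ["precision", "recall", "f1"]},
    "bleu": {"returns": "bleu"},
    "meteor": {"returns": "meteor"},
    "sacrebleu": {"returns": "score"},
}
print(metric_column_names(metrics_cfg))
# → ['rouge1', 'rouge2', 'rougeL', 'bertscore_precision', 'bertscore_recall',
#    'bertscore_f1', 'bleu', 'meteor', 'sacrebleu']
```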

model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics, calc_every='other_epoch')]

learn = Learner(dls, 
                model,
                opt_func=partial(Adam),
                loss_func=CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) #.to_native_fp16() #.to_fp16()

learn.unfreeze()
b = dls.one_batch()
preds = learn.model(b[0])

len(preds),preds['loss'].shape, preds['logits'].shape
(4, torch.Size([]), torch.Size([2, 69, 50264]))
print(len(learn.opt.param_groups))
3
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=6.918309954926372e-05, steep=1.4454397387453355e-05, valley=5.248074739938602e-05, slide=2.511886486900039e-05)
learn.fit_one_cycle(3, lr_max=slice(9e-7, 9e-5), cbs=fit_cbs)
epoch train_loss valid_loss rouge1 rouge2 rougeL bertscore_precision bertscore_recall bertscore_f1 bleu meteor sacrebleu time
0 1.980793 1.820687 None None None None None None None None None 02:01
1 1.188804 1.828470 0.380084 0.160968 0.259009 0.876564 0.893166 0.884687 0.135892 0.300841 11.008758 04:25
2 0.439346 2.136937 0.383360 0.163631 0.263455 0.880336 0.892410 0.886238 0.146506 0.296035 11.916852 04:05

Showing results

Below we'll add additional functionality to take advantage of Hugging Face's PreTrainedModel.generate method, which makes it easy to use beam search, top-k/nucleus sampling, etc., so that we get more human-sounding results.

test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan 
contributed to this report.
"""
res = learn.blurr_predict(test_article)
print(hf_tokenizer.decode(res[0][0][0][:20]))
<s><s>
Gun 10 men armed with pistols and small machine guns raided a casino in Switzerland. made

That doesn't look much like human-generated text. Let's use Hugging Face's PreTrainedModel.generate method to create something more human-like.

b = dls.valid.one_batch()

tfm = first_blurr_tfm(dls)

b_hf_tokenizer = tfm.hf_tokenizer
b_ignore_token_id = tfm.ignore_token_id

test_input_ids = b[0]['input_ids'][0].unsqueeze(0).to(learn.model.hf_model.device)
test_trg_ids = b[1][0].unsqueeze(0).to(learn.model.hf_model.device)
test_trg_ids = [ trg[trg != b_ignore_token_id] for trg in test_trg_ids ]

gen_text = learn.model.hf_model.generate(test_input_ids, num_beams=4, max_length=130, min_length=30)

print('=== Target ===')
print(f'{b_hf_tokenizer.decode(test_trg_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)}\n')

print('=== Prediction ===')
print(b_hf_tokenizer.decode(gen_text[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
=== Target ===
 Hotel guests who "go green" are happier with their stay.
Increasing water and energy costs are pushing hotels to cut costs wherever they can.
Many hotels find that guests don't mind using the same towels and sheets every night.
TripAdvisor will be adding a green label for hotels listed on its site.

=== Prediction ===
 Hotel guests may not be as environmentally conscious as they might seem.
Dan Condon uses a new towel every day when he travels for work and stays in a hotel.
Condon: "I could care less about rewards for environmentally conscious behavior unless it's miles"
Hotels can't convince eco-conscious guests to go green while traveling.

To make things even easier, for text generation tasks you can simply call the Learner.blurr_generate method, optionally passing in whatever text generation kwargs you wish, to accomplish the same as above.

outputs = learn.blurr_generate(test_article, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 Gunmen made off with hundreds of thousands of Swiss francs, police say .
Gunmen dressed in black clothes and black ski masks, split into two groups during raid on Grand Casino Basel .
One group failed to break into casino vault, but did rob cashier of money that was not secured .
There were about 600 people in the casino at the time of the robbery, police officer says .

=== Prediction 2 ===
 Gunmen made off with hundreds of thousands of Swiss francs, police say .
Gunmen dressed in black clothes and black ski masks, split into two groups during raid on Grand Casino Basel .
One group failed to break into casino vault, but did rob cashier of money that was not secured .
There were about 600 people in casino at the time of the robbery, police officer says .

=== Prediction 3 ===
 Gunmen made off with hundreds of thousands of Swiss francs, police say .
Gunmen dressed in black clothes and black ski masks, split into two groups during raid on Grand Casino Basel .
One group failed to break into casino vault, but did rob cashier of money that was not secured .
About 600 people were in the casino at the time of the robbery, police officer says .

Much nicer! Now, we can update our @typedispatched show_results to use this new method.

learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 Dan Condon believes in recycling. Just not when it comes to his hotel towels. Condon composts when he's at home in Boulder, Colorado. He eats local, organic and fair-trade food and drives a Honda CR-Z hybrid sports car. You might call him green. Except he's not so green when he travels for his work at an education nonprofit and stays in a hotel, which happens about 10 weeks per year. There, he uses a new towel every day. And don't try to bribe him with a drink or dessert coupon to get him to re Hotel guests who "go green" are happier with their stay.\nIncreasing water and energy costs are pushing hotels to cut costs wherever they can.\nMany hotels find that guests don't mind using the same towels and sheets every night.\nTripAdvisor will be a Hotel guests may not be as environmentally conscious as they might seem .\nDan Condon uses a new towel every day when he travels for work and stays in a hotel .\nCondon: "I could care less about rewards for environmentally conscious behavior unless it
1 (CNN Student News) -- March 23, 2010. Download PDF maps related to today's show:. • Haiti • China. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. CARL AZUZ, CNN STUDENT NEWS ANCHOR: Happy birthday, Roger Bannister -- first man to run the mile in less than four minutes. In more than twice that time, you'll be up to speed on today's headlines. I'm Carl Azuz. First Up: Health Care. AZUZ: First up, it's the biggest expansion of the United States he Find out what comes next after the passage of a health care reform bill.\nLearn about a proposal that would change how student loans are funded.\nFollow the steps that led to a showdown between China and Google.\nUse the Daily Discussion to help studen Consider the biggest expansion of the U.S. health care system in more than forty years .\nHear how some lawmakers are responding to the House health care bill's passage .\nMeet the first man to run the mile in less than four minutes .\nUse the Daily Di

Inference

export_fname = 'summarize_export'
learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_article)
[' Gunmen made off with hundreds of thousands of Swiss francs, police say .\nGunmen dressed in black clothes and black ski masks, split into two groups during raid on Grand Casino Basel .\nOne group failed to break into casino vault, but did rob cashier of money that was not secured .\nThere were about 600 people in the casino at the time of the robbery, police officer says .']

Summary

This module includes the fundamental bits for training and inference with Seq2Seq transformers.