This notebook demonstrates how we can use Blurr to tackle the General Language Understanding Evaluation (GLUE) benchmark tasks.
 
Here's what we're running with ...

torch: 1.9.0+cu102
fastai: 2.5.2
transformers: 4.10.0
Using GPU #1: GeForce GTX 1080 Ti

GLUE tasks

| Abbr | Name | Task type | Description | Size | Metrics |
| --- | --- | --- | --- | --- | --- |
| CoLA | Corpus of Linguistic Acceptability | Single-Sentence Tasks | Predict whether a sequence is a grammatical English sentence | 8.5k | Matthews corr. |
| SST-2 | Stanford Sentiment Treebank | Single-Sentence Tasks | Predict the sentiment of a given sentence | 67k | Accuracy |
| MRPC | Microsoft Research Paraphrase Corpus | Similarity and Paraphrase Tasks | Predict whether two sentences are semantically equivalent | 3.7k | F1/Accuracy |
| STS-B | Semantic Textual Similarity Benchmark | Similarity and Paraphrase Tasks | Predict the similarity score for two sentences on a scale from 1 to 5 | 7k | Pearson/Spearman corr. |
| QQP | Quora Question Pairs | Similarity and Paraphrase Tasks | Predict if two questions are a paraphrase of one another | 364k | F1/Accuracy |
| MNLI | Multi-Genre Natural Language Inference | Inference Tasks | Predict whether the premise entails, contradicts, or is neutral to the hypothesis | 393k | Accuracy |
| QNLI | Stanford Question Answering Dataset | Inference Tasks | Predict whether the context sentence contains the answer to the question | 105k | Accuracy |
| RTE | Recognizing Textual Entailment | Inference Tasks | Predict whether one sentence entails another | 2.5k | Accuracy |
| WNLI | Winograd Schema Challenge | Inference Tasks | Predict if the sentence with the pronoun substituted is entailed by the original sentence | 634 | Accuracy |
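The metrics column hints at why some tasks report something other than plain accuracy: CoLA, for instance, is heavily imbalanced, and Matthews correlation stays honest on skewed label distributions. As a quick refresher (this is an illustrative helper, not part of the Blurr API), it can be computed directly from confusion-matrix counts:

```python
import math

def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient from confusion-matrix counts.

    Ranges from -1 (total disagreement) to +1 (perfect prediction);
    0 corresponds to random guessing, even on imbalanced data.
    """
    num = tp * tn - fp * fn
    den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return num / den if den else 0.0

# A classifier that labels everything "grammatical" can score 90% accuracy
# on a 90/10 split, yet its MCC is 0 -- no better than guessing
print(mcc(tp=90, tn=0, fp=10, fn=0))  # 0.0
```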

Define the task and hyperparameters

We'll use the "distilroberta-base" checkpoint for this example, but if you want to try an architecture that returns token_type_ids, for example, you can use something like "bert-base-cased".

task = 'mrpc'
task_meta = glue_tasks[task]
train_ds_name = task_meta['dataset_names']["train"]
valid_ds_name = task_meta['dataset_names']["valid"]
test_ds_name = task_meta['dataset_names']["test"]

task_inputs = task_meta['inputs']
task_target = task_meta['target']
task_metrics = task_meta['metric_funcs']

pretrained_model_name = "distilroberta-base" # bert-base-cased | distilroberta-base

bsz = 16
val_bsz = bsz * 2

Prepare the datasets

Let's start by building our DataBlock. We'll load the MRPC dataset from Hugging Face's datasets library, which will be cached after downloading via the load_dataset method. For more information on the datasets API, see the documentation here.

raw_datasets = load_dataset('glue', task) 
print(f'{raw_datasets}\n')
print(f'{raw_datasets[train_ds_name][0]}\n')
print(f'{raw_datasets[train_ds_name].features}\n')
Reusing dataset glue (/home/wgilliam/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1725
    })
})

{'idx': 0, 'label': 1, 'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .', 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'}

{'sentence1': Value(dtype='string', id=None), 'sentence2': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None)}
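Note the ClassLabel feature in the output above: labels are stored as integers, with 0 mapping to not_equivalent and 1 to equivalent. A standalone sketch of that mapping (the names list is copied from the features output; int2str here just mirrors what datasets' ClassLabel.int2str does):

```python
# Label names as reported by the ClassLabel feature above
names = ['not_equivalent', 'equivalent']

def int2str(label_id):
    # Map an integer class id back to its human-readable name
    return names[label_id]

example = {'idx': 0, 'label': 1}
print(int2str(example['label']))  # 'equivalent'
```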

There are a variety of ways we can preprocess the dataset for DataBlock consumption. For example, we could push the data into a DataFrame, add a boolean is_valid column, and use the ColSplitter method to define our train/validation splits like this:

raw_train_df = pd.DataFrame(raw_datasets[train_ds_name], columns=list(raw_datasets[train_ds_name].features.keys()))
raw_train_df['is_valid'] = False

raw_valid_df = pd.DataFrame(raw_datasets[valid_ds_name], columns=list(raw_datasets[train_ds_name].features.keys()))
raw_valid_df['is_valid'] = True

raw_df = pd.concat([raw_train_df, raw_valid_df])
print(len(raw_df))
raw_df.head()
4076
sentence1 sentence2 label idx is_valid
0 Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence . Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence . 1 0 False
1 Yucaipa owned Dominick 's before selling the chain to Safeway in 1998 for $ 2.5 billion . Yucaipa bought Dominick 's in 1995 for $ 693 million and sold it to Safeway for $ 1.8 billion in 1998 . 0 1 False
2 They had published an advertisement on the Internet on June 10 , offering the cargo for sale , he added . On June 10 , the ship 's owners had published an advertisement on the Internet , offering the explosives for sale . 1 2 False
3 Around 0335 GMT , Tab shares were up 19 cents , or 4.4 % , at A $ 4.56 , having earlier set a record high of A $ 4.57 . Tab shares jumped 20 cents , or 4.6 % , to set a record closing high at A $ 4.57 . 0 3 False
4 The stock rose $ 2.11 , or about 11 percent , to close Friday at $ 21.51 on the New York Stock Exchange . PG & E Corp. shares jumped $ 1.63 or 8 percent to $ 21.03 on the New York Stock Exchange on Friday . 1 4 False

Another option is to capture the indexes for both the train and validation sets, use the datasets library's concatenate_datasets function to combine them into a single dataset, and finally use the IndexSplitter method to define our train/validation splits as such:

n_train, n_valid = raw_datasets[train_ds_name].num_rows, raw_datasets[valid_ds_name].num_rows
train_idxs, valid_idxs = L(range(n_train)), L(range(n_train, n_train + n_valid))
raw_ds = concatenate_datasets([raw_datasets[train_ds_name], raw_datasets[valid_ds_name]])
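The idea behind IndexSplitter is simple: given the validation indexes, split the concatenated dataset's row positions into a train list and a valid list. A minimal pure-Python sketch of that behavior (independent of fastai, for illustration only):

```python
def index_splitter(valid_idxs):
    """Return a fastai-style splitter: items -> (train_idxs, valid_idxs)."""
    valid = set(valid_idxs)
    def _split(items):
        train = [i for i in range(len(items)) if i not in valid]
        return train, sorted(valid)
    return _split

# With 3,668 training rows followed by 408 validation rows (as in MRPC),
# validation occupies positions 3668..4075 of the concatenated dataset
n_train, n_valid = 3668, 408
split = index_splitter(range(n_train, n_train + n_valid))
train_idxs, valid_idxs = split(range(n_train + n_valid))
print(len(train_idxs), len(valid_idxs))  # 3668 408
```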

Mid-level API

Prepare the huggingface objects

How many classes are we working with? Depending on which preprocessing approach you chose above, you can count them in one of the two ways below.

n_lbls = raw_df[task_target].nunique(); n_lbls
2
n_lbls = len(set([item[task_target] for item in raw_ds])); n_lbls
2
model_cls = AutoModelForSequenceClassification

config = AutoConfig.from_pretrained(pretrained_model_name)
config.num_labels = n_lbls

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=model_cls, 
                                                                  config=config)

print(hf_arch)
print(type(hf_config))
print(type(hf_tokenizer))
print(type(hf_model))
roberta
<class 'transformers.models.roberta.configuration_roberta.RobertaConfig'>
<class 'transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast'>
<class 'transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification'>

Build the DataBlock

blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), CategoryBlock())

def get_x(r, attr): 
    # Return a single text when `attr` is a string, or a tuple of texts for sentence-pair tasks
    return r[attr] if (isinstance(attr, str)) else tuple(r[inp] for inp in attr)
    
dblock = DataBlock(blocks=blocks, 
                   get_x=partial(get_x, attr=task_inputs), 
                   get_y=ItemGetter(task_target), 
                   splitter=IndexSplitter(valid_idxs))
dls = dblock.dataloaders(raw_ds, bs=bsz, val_bs=val_bsz)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([16, 103]), torch.Size([16]))
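To make get_x's behavior concrete, here's a standalone check (re-defining the helper so the snippet runs on its own, with made-up sentences): for a string attr it returns one text, while for a list it returns a tuple of texts, which is how sentence-pair tasks like MRPC feed both inputs to the tokenizer.

```python
def get_x(r, attr):
    # Single input: return the text; multiple inputs: return a tuple of texts
    return r[attr] if isinstance(attr, str) else tuple(r[inp] for inp in attr)

row = {'sentence1': 'Amrozi accused his brother.', 'sentence2': 'Amrozi accused him.'}

print(get_x(row, attr='sentence1'))                 # single-sentence task
print(get_x(row, attr=['sentence1', 'sentence2']))  # sentence-pair task
```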
if ('token_type_ids' in b[0]):
    print([(hf_tokenizer.convert_ids_to_tokens(inp_id.item()), inp_id.item(), tt_id.item() )
           for inp_id, tt_id in zip (b[0]['input_ids'][0], b[0]['token_type_ids'][0]) 
           if inp_id != hf_tokenizer.pad_token_id])
dls.show_batch(dataloaders=dls, max_n=5)
text target
0 " In Iraq, " Sen. Pat Roberts, R-Kan., chairman of the intelligence committee, said on CNN's " Late Edition " Sunday, " we're now fighting an anti-guerrilla... effort. " " In Iraq, " Sen. Pat Roberts ( R-Kan. ), chairman of the intelligence committee, said on CNN's " Late Edition " yesterday, " we're now fighting an anti-guerrilla... effort. " 1
1 Regional utility Tohoku Electric Power Co Inc said an 825,000-kilowatt ( kW ) nuclear reactor, the Onagawa No.3 unit near Sendai automatically shut down due to the quake. Japan's Tohoku Electric Power Co Inc said an 825,000-kilowatt ( kW ) nuclear reactor, the Onagawa No.3 unit in northern Japan, automatically shut down due to the quake. 1
2 Drewes and a friend were playing a game of " ding-dong-ditch " -- ringing doorbells and running away -- in the Woodbury neighborhood in suburban Boca Raton. Drewes and his friend were pulling a mischievous, late-night game of " ding-dong-ditch " knocking on doors or ringing doorbells and running in the Woodbury neighborhood in suburban Boca Raton. 1
3 The updated 64-bit operating system, Windows XP 64-Bit Edition for 64-Bit Extended Systems, will run natively on AMD Athlon 64 processor-powered desktops and AMD Opteron processor-powered workstations. Windows XP 64-bit Edition for 64-Bit Extended Systems will support AMD64 technology, running natively on AMD Athlon 64 powered desktops and AMD Opteron processor-powered workstations. 1
4 The $ 19.50-a-share bid, comes two days after PeopleSoft revised its bid for smaller rival J.D. Edwards & Co. JDEC.O to include cash as well as stock. Oracle's $ 19.50-a-share bid comes two days after PeopleSoft added cash to its original all-share deal with smaller rival J.D. Edwards & Co. JDEC.O. 1

Train

With our DataLoaders built, we can now build our Learner and train. We'll use mixed precision so we can train with bigger batches.

model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls, 
                model,
                opt_func=partial(Adam),
                loss_func=CrossEntropyLossFlat(),
                metrics=task_metrics,
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()

learn.freeze()
learn.summary()
HF_BaseModelWrapper (Input shape: 16)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     16 x 103 x 768      
Embedding                                 38603520   False     
Embedding                                 394752     False     
Embedding                                 768        False     
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     False     
Linear                                    590592     False     
Linear                                    590592     False     
Dropout                                                        
Linear                                    590592     False     
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     16 x 103 x 3072     
Linear                                    2362368    False     
____________________________________________________________________________
                     16 x 103 x 768      
Linear                                    2360064    False     
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     False     
Linear                                    590592     False     
Linear                                    590592     False     
Dropout                                                        
Linear                                    590592     False     
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     16 x 103 x 3072     
Linear                                    2362368    False     
____________________________________________________________________________
                     16 x 103 x 768      
Linear                                    2360064    False     
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     False     
Linear                                    590592     False     
Linear                                    590592     False     
Dropout                                                        
Linear                                    590592     False     
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     16 x 103 x 3072     
Linear                                    2362368    False     
____________________________________________________________________________
                     16 x 103 x 768      
Linear                                    2360064    False     
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     False     
Linear                                    590592     False     
Linear                                    590592     False     
Dropout                                                        
Linear                                    590592     False     
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     16 x 103 x 3072     
Linear                                    2362368    False     
____________________________________________________________________________
                     16 x 103 x 768      
Linear                                    2360064    False     
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     False     
Linear                                    590592     False     
Linear                                    590592     False     
Dropout                                                        
Linear                                    590592     False     
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     16 x 103 x 3072     
Linear                                    2362368    False     
____________________________________________________________________________
                     16 x 103 x 768      
Linear                                    2360064    False     
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     False     
Linear                                    590592     False     
Linear                                    590592     False     
Dropout                                                        
Linear                                    590592     False     
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     16 x 103 x 3072     
Linear                                    2362368    False     
____________________________________________________________________________
                     16 x 103 x 768      
Linear                                    2360064    False     
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Dropout                                                        
____________________________________________________________________________
                     16 x 2              
Linear                                    1538       True      
____________________________________________________________________________

Total params: 82,119,938
Total trainable params: 612,098
Total non-trainable params: 81,507,840

Optimizer used: functools.partial(<function Adam at 0x7f6ea70423a0>)
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - HF_BaseModelCallback
  - MixedPrecision
  - Recorder
  - ProgressCallback
preds = model(b[0])
preds.logits.shape, preds
(torch.Size([16, 2]),
 SequenceClassifierOutput(loss=None, logits=tensor([[ 0.1265, -0.1490],
         [ 0.1444, -0.1744],
         [ 0.1395, -0.1919],
         [ 0.1403, -0.1818],
         [ 0.1515, -0.1665],
         [ 0.1405, -0.1684],
         [ 0.1455, -0.1795],
         [ 0.1561, -0.1865],
         [ 0.1496, -0.1692],
         [ 0.1450, -0.1842],
         [ 0.1440, -0.1724],
         [ 0.1491, -0.1704],
         [ 0.1365, -0.1640],
         [ 0.1521, -0.1618],
         [ 0.1394, -0.1710],
         [ 0.1453, -0.1753]], device='cuda:1', grad_fn=<AddmmBackward>), hidden_states=None, attentions=None))
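Note that the model returns raw, unnormalized logits. To turn a row of logits into class probabilities you apply a softmax over the last dimension, which is what produces the probability rows you'll see from get_preds later. A minimal pure-Python sketch (the logits below are the first row from the output above):

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([0.1265, -0.1490])
print(probs)  # two probabilities summing to 1
```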
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
/home/wgilliam/miniconda3/envs/blurr/lib/python3.9/site-packages/fastai/callback/schedule.py:269: UserWarning: color is redundantly defined by the 'color' keyword argument and the fmt string "ro" (-> color='r'). The keyword argument will take precedence.
  ax.plot(val, idx, 'ro', label=nm, c=color)
SuggestedLRs(minimum=0.0019054606556892395, steep=6.309573450380412e-07, valley=0.0003981071640737355, slide=0.0020892962347716093)
learn.fit_one_cycle(1, lr_max=2e-3)
epoch train_loss valid_loss f1_score accuracy time
0 0.525879 0.487189 0.831650 0.754902 00:10
learn.unfreeze()
learn.lr_find(start_lr=1e-12, end_lr=2e-3, suggest_funcs=[minimum, steep, valley, slide])
/home/wgilliam/miniconda3/envs/blurr/lib/python3.9/site-packages/fastai/callback/schedule.py:269: UserWarning: color is redundantly defined by the 'color' keyword argument and the fmt string "ro" (-> color='r'). The keyword argument will take precedence.
  ax.plot(val, idx, 'ro', label=nm, c=color)
SuggestedLRs(minimum=1.53075743583031e-05, steep=1.6185790555067747e-11, valley=4.015422746306285e-06, slide=1.7980568372877315e-05)
learn.fit_one_cycle(2, lr_max=slice(2e-5, 2e-4))
epoch train_loss valid_loss f1_score accuracy time
0 0.478144 0.356452 0.897666 0.860294 00:18
1 0.259431 0.302929 0.910035 0.872549 00:18
learn.show_results(learner=learn, max_n=5)
text target prediction
0 He said the foodservice pie business doesn 't fit the company's long-term growth strategy. " The foodservice pie business does not fit our long-term growth strategy. 1 1
1 But I would rather be talking about high standards than low standards. " " I would rather be talking about positive numbers rather than negative. 1 0
2 Gasps could be heard in the courtroom when the photo was displayed. Gasps could be heard as the photo was projected onto the screen. 1 1
3 Mr. Rowland attended a party in South Windsor for the families of Connecticut National Guard soldiers called to active duty. Rowland was making an appearance at a holiday party for families of Connecticut National Guard soldiers assigned to duty in Iraq and Afghanistan. 1 1
4 Let me just say this : the evidence that we have of weapons of mass destruction was evidence drawn up and accepted by the joint intelligence community. " The evidence that we had of weapons of mass destruction was drawn up and accepted by the Joint Intelligence Committee, " he said. 1 1

Evaluate

How did we do?

val_res = learn.validate()
val_res_d = { 'loss': val_res[0]}
for idx, m in enumerate(learn.metrics):
    val_res_d[m.name] = val_res[idx+1]
    
val_res_d
{'loss': 0.30292922258377075,
 'f1_score': 0.9100346020761245,
 'accuracy': 0.8725489974021912}
preds, targs, losses = learn.get_preds(with_loss=True)
print(preds.shape, targs.shape, losses.shape)
print(losses.mean(), accuracy(preds, targs))
torch.Size([408, 2]) torch.Size([408]) torch.Size([408])
TensorBase(0.3029) TensorBase(0.8725)

Inference

Let's do item inference on an example from our test dataset.

raw_test_df = pd.DataFrame(raw_datasets[test_ds_name], columns=list(raw_datasets[test_ds_name].features.keys()))
raw_test_df.head(10)
sentence1 sentence2 label idx
0 PCCW 's chief operating officer , Mike Butcher , and Alex Arena , the chief financial officer , will report directly to Mr So . Current Chief Operating Officer Mike Butcher and Group Chief Financial Officer Alex Arena will report to So . 1 0
1 The world 's two largest automakers said their U.S. sales declined more than predicted last month as a late summer sales frenzy caused more of an industry backlash than expected . Domestic sales at both GM and No. 2 Ford Motor Co. declined more than predicted as a late summer sales frenzy prompted a larger-than-expected industry backlash . 1 1
2 According to the federal Centers for Disease Control and Prevention ( news - web sites ) , there were 19 reported cases of measles in the United States in 2002 . The Centers for Disease Control and Prevention said there were 19 reported cases of measles in the United States in 2002 . 1 2
3 A tropical storm rapidly developed in the Gulf of Mexico Sunday and was expected to hit somewhere along the Texas or Louisiana coasts by Monday night . A tropical storm rapidly developed in the Gulf of Mexico on Sunday and could have hurricane-force winds when it hits land somewhere along the Louisiana coast Monday night . 0 3
4 The company didn 't detail the costs of the replacement and repairs . But company officials expect the costs of the replacement work to run into the millions of dollars . 0 4
5 The settling companies would also assign their possible claims against the underwriters to the investor plaintiffs , he added . Under the agreement , the settling companies will also assign their potential claims against the underwriters to the investors , he added . 1 5
6 Air Commodore Quaife said the Hornets remained on three-minute alert throughout the operation . Air Commodore John Quaife said the security operation was unprecedented . 0 6
7 A Washington County man may have the countys first human case of West Nile virus , the health department said Friday . The countys first and only human case of West Nile this year was confirmed by health officials on Sept . 8 . 1 7
8 Moseley and a senior aide delivered their summary assessments to about 300 American and allied military officers on Thursday . General Moseley and a senior aide presented their assessments at an internal briefing for American and allied military officers at Nellis Air Force Base in Nevada on Thursday . 1 8
9 The broader Standard & Poor 's 500 Index < .SPX > was 0.46 points lower , or 0.05 percent , at 997.02 . The technology-laced Nasdaq Composite Index .IXIC was up 7.42 points , or 0.45 percent , at 1,653.44 . 0 9
learn.blurr_predict(raw_test_df.iloc[9].to_dict())
[(('0',), (#1) [tensor(0)], (#1) [tensor([0.9331, 0.0669])])]

Let's do batch inference on the entire test dataset.

test_dl = dls.test_dl(raw_datasets[test_ds_name])
preds = learn.get_preds(dl=test_dl)
preds
(tensor([[0.0105, 0.9895],
         [0.0093, 0.9907],
         [0.0054, 0.9946],
         ...,
         [0.0583, 0.9417],
         [0.0104, 0.9896],
         [0.0664, 0.9336]]),
 None)
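Each row of the predictions tensor holds class probabilities, so the predicted label for a row is simply its argmax. A quick pure-Python equivalent of calling argmax(dim=-1) on the tensor above (using the first and last visible rows as sample data):

```python
def predicted_labels(prob_rows):
    # Index of the largest probability in each row = predicted class id
    return [max(range(len(row)), key=row.__getitem__) for row in prob_rows]

print(predicted_labels([[0.0105, 0.9895], [0.0664, 0.9336]]))  # [1, 1]
```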

High-level API

With the high-level API, we can create our DataBlock, DataLoaders, and Blearner in one line of code.

dl_kwargs = {'bs': bsz, 'val_bs': val_bsz}
learn_kwargs = { 'metrics': task_metrics }

learn = BlearnerForSequenceClassification.from_dataframe(raw_df, pretrained_model_name, 
                                                         text_attr=task_inputs, label_attr=task_target,
                                                         dl_kwargs=dl_kwargs, learner_kwargs=learn_kwargs)
learn.fit_one_cycle(1, lr_max=2e-3)
epoch train_loss valid_loss f1_score accuracy time
0 0.518781 0.466930 0.860390 0.789216 00:09
learn.show_results(learner=learn, max_n=5)
text target prediction
0 He said the foodservice pie business doesn 't fit the company's long-term growth strategy. " The foodservice pie business does not fit our long-term growth strategy. 1 1
1 But Virgin wants to operate Concorde on routes to New York, Barbados and Dubai. Branson said that his preference would be to operate a fully commercial service on routes to New York, Barbados and Dubai. 1 1
2 On the stand Wednesday, she said she was referring only to the kissing. On the stand Wednesday, she testified that she was referring to the kissing before the alleged rape. 0 1
3 The poll had a margin of error of plus or minus 2 percentage points. It had a margin of sampling error of plus or minus four percentage points and was conducted Thursday through Saturday. 0 0
4 Some of the computers also are used to send spam e-mail messages to drum up traffic to the sites. Some are also used to send spam e-mail messages to boost traffic to the sites. 1 1

Summary

The general flow of this notebook was inspired by Zach Mueller's "Text Classification with Transformers" example, which can be found in the wonderful Walk With Fastai docs. Take a look there for another approach to working with fast.ai and Hugging Face on GLUE tasks.