Modeling

The text.modeling.core module contains core custom models, loss functions, and a default layer group splitter for applying discriminative learning rates to your Hugging Face models trained with fastai.

Mid-level API

Base splitter, model wrapper, and model callback


source

blurr_splitter

 blurr_splitter (m:fastai.torch_core.Module)

Splits the Hugging Face model based on various model architecture conventions
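To see what the splitter produces, you can call it directly and inspect the resulting layer groups. A minimal sketch (it assumes the hf_model and BaseModelWrapper objects created in the example further down this page; the 3-group result is what distilroberta-base yields there):

wrapped_model = BaseModelWrapper(hf_model)
param_groups = blurr_splitter(wrapped_model)   # one list of parameters per layer group
print(len(param_groups))                       # e.g., 3 groups for distilroberta-base
print([len(list(g)) for g in param_groups])    # number of parameters in each group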


source

BaseModelWrapper

 BaseModelWrapper (hf_model:transformers.modeling_utils.PreTrainedModel,
                   output_hidden_states:bool=False,
                   output_attentions:bool=False, hf_model_kwargs={})

Same as nn.Module, but no need for subclasses to call super().__init__

Argument              Type             Default  Details
hf_model              PreTrainedModel           Your Hugging Face model
output_hidden_states  bool             False    If True, hidden_states will be returned and accessed from Learner
output_attentions     bool             False    If True, attentions will be returned and accessed from Learner
hf_model_kwargs       dict             {}       Any additional keyword arguments you want passed into your model's forward method

Note that BaseModelWrapper includes some nifty code for just passing in the things your model needs, as not all transformer architectures require/use the same information.
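Conceptually, this amounts to filtering the batch's keyword arguments against the model's forward signature. A rough, hypothetical sketch of that idea (not blurr's actual implementation; filter_model_kwargs is an invented helper name):

import inspect

def filter_model_kwargs(hf_model, batch_kwargs: dict) -> dict:
    # Hypothetical helper: keep only the items the model's forward() accepts,
    # e.g. DistilBERT takes no token_type_ids while BERT does.
    accepted = set(inspect.signature(hf_model.forward).parameters)
    return {k: v for k, v in batch_kwargs.items() if k in accepted}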


source

BaseModelCallback

 BaseModelCallback (base_model_wrapper_kwargs:dict={})

Basic class handling tweaks of the training loop by changing a Learner in various events

Argument                   Type  Default  Details
base_model_wrapper_kwargs  dict  {}       Additional keyword arguments passed to BaseModelWrapper

We use a Callback for handling the ModelOutput returned by Hugging Face transformers. It allows us to associate anything we want from that object to our Learner.

Note that your Learner’s loss will be set for you only if the Hugging Face model returns one and you are using the PreCalculatedLoss loss function.

Also note that anything else you asked the model to return (for example, the last hidden state) will be available to you via the blurr_model_outputs property attached to your Learner. For example, assuming you are using BERT for a classification task, if you have told your BaseModelWrapper instance to return attentions, you'd be able to access them via learn.blurr_model_outputs['attentions'].
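Put together, a hedged sketch of what that looks like (assuming the dls, hf_model, and loss function built in the example below, and that a forward pass has been run, e.g. via get_preds):

# Sketch only: ask the wrapper to return attentions, then read them back off the
# Learner after a forward pass. Class names match those documented above.
learn = Learner(
    dls,
    BaseModelWrapper(hf_model, output_attentions=True),
    loss_func=PreCalculatedCrossEntropyLoss(),
    cbs=[BaseModelCallback],
    splitter=blurr_splitter,
)
preds = learn.get_preds()                               # forward pass over the validation set
attentions = learn.blurr_model_outputs["attentions"]    # per the note above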

Example

Below we demonstrate how to set up your pipeline for a sequence classification task (e.g., a model that requires a single text input) using the mid-, high-, and low-level APIs.

raw_datasets = load_dataset("imdb", split=["train", "test"])
raw_datasets[0] = raw_datasets[0].add_column("is_valid", [False] * len(raw_datasets[0]))
raw_datasets[1] = raw_datasets[1].add_column("is_valid", [True] * len(raw_datasets[1]))

final_ds = concatenate_datasets([raw_datasets[0].shuffle().select(range(1000)), raw_datasets[1].shuffle().select(range(200))])
imdb_df = pd.DataFrame(final_ds)
imdb_df.head()
Reusing dataset imdb (/home/wgilliam/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
text label is_valid
0 I wanted to see this movie ever since it was first advertised on TV. I went to Tinsel Town to see it Last Night at 7:40. I regret the day that wasted my ticket on this trash when I could of saw something better. The beginning was all a bunch sex trash and cliches. They exaggerated the way love works in reality. All of the girls were stereo types. The boyfriend was too stupid for his own age. The passing gases that the pregnant girl kept having barely got any laughs. The bank robbery was completely boring with gags that have been used in other movies. Their getaway car was an old beat up Ch... 0 False
1 As a huge baseball fan, my scrutiny of this film is how realistic it appears. Dennis Quaid had all of the right moves and stances of a major league pitcher. It is a fantastic true story told with just a little too much "Disney" for my taste. 1 False
2 This ranks as one of the worst movies I've seen in years. Besides Cuba and Angie, the acting is actually embarrassing. Wasn't Archer once a decent actress? What happened to her? The action is decent but completely implausible. The make up is so bad it's worth mentioning. I mean, who ever even thinks about the makeup in a contemporary feature film. Someone should tell the make up artist, and the DOP that you're not supposed to actually see it. The ending is a massive disappointment - along the lines of "and then they realized it was all a dream"<br /><br />Don't waste your time or your mone... 0 False
3 For those of us Baby Boomers who arrived too late on the scene to appreciate James Dean et. al., Martin Sheen showed us The Way in this great feature.<br /><br />The premise is easy enough: cool hood meets small town sheriff and All-Hell ensues, but the nuts and bolts of this movie enthrall the car nut in all of us. <br /><br />No, this isn't Casablanca, nor is it great Literature, but it IS a serious movie about cars, rebellion, and the genius that is Martin Sheen.<br /><br />Enjoy this and appreciate it for what it is, and for what Martin will become. I loved this movie growing up as a t... 1 False
4 Similar to "On the Town," this musical about sailors on shore leave falls short of the later classic in terms of pacing and the quality of the songs, but it has its own charms. Kelly has three fabulous dance routines: one with Jerry the cartoon mouse of "Tom and Jerry" fame, one with a little girl, and a fantasy sequence where he is a Spanish lover determined to reach his lady on a high balcony. Sinatra, playing Kelly's shy, inexperienced buddy, and Grayson, the woman who serves as the love interest for both men, do most of the singing. Iturbi provides some fine piano playing. At nearly tw... 1 False
labels = raw_datasets[0].features["label"].names
labels
['neg', 'pos']
model_cls = AutoModelForSequenceClassification
hf_logging.set_verbosity_error()

pretrained_model_name = "distilroberta-base"  # "distilbert-base-uncased" "bert-base-uncased"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
# single input
set_seed()
blocks = (TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model, batch_tokenize_kwargs={"labels": labels}), CategoryBlock)
dblock = DataBlock(blocks=blocks, get_x=ColReader("text"), get_y=ColReader("label"), splitter=RandomSplitter(seed=42))
dls = dblock.dataloaders(imdb_df, bs=4)
dls.show_batch(dataloaders=dls, max_n=2, trunc_at=500)
text target
0 My Comments for VIVAH :- Its a charming, idealistic love story starring Shahid Kapoor and Amrita Rao. The film takes us back to small pleasures like the bride and bridegroom's families sleeping on the floor, playing games together, their friendly banter and mutual respect. Vivah is about the sanctity of marriage and the importance of commitment between two individuals. Yes, the central romance is naively visualized. But the sneaked-in romantic moments between the to-be-married couple and their pos
1 WWE Armageddon, December 17, 2006 -- Live from Richmond Coliseum, Richmond, VA <br /><br />Kane vs. MVP in an Inferno match: So this is the fourth ever inferno match in the WWE and it is Kane vs. MVP (wonder why was it the first match on the card). I only viewed the ending parts where Kane sets MVP's ass on fire as they're on the apron and then MVP is running around the arena while yelling – eventually the refs put out the fire with a fire extinguisher as MVP sprawls around the entrance ramp. F pos

Training

.to_fp16() requires a GPU, so it was removed so that the tests can run on GitHub. Let's check that we can get predictions.

set_seed()

model = BaseModelWrapper(hf_model)

learn = Learner(
    dls,
    model,
    opt_func=partial(OptimWrapper, opt=torch.optim.Adam),
    loss_func=PreCalculatedCrossEntropyLoss(),  # CrossEntropyLossFlat(),
    metrics=[accuracy],
    cbs=[BaseModelCallback],
    splitter=blurr_splitter,
)

learn.freeze()
learn.summary()
print(len(learn.opt.param_groups))
3
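Since .to_fp16() was removed above so the tests run on CPU, here is an optional sketch of re-enabling mixed precision yourself with the standard fastai call when a GPU is available:

# Optional: switch the Learner to mixed-precision training on a CUDA device.
if torch.cuda.is_available():
    learn = learn.to_fp16()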
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=0.00012022644514217973, steep=0.0063095735386013985, valley=0.0003311311302240938, slide=0.0020892962347716093)

set_seed()
learn.fit_one_cycle(1, lr_max=1e-3)
epoch train_loss valid_loss accuracy time
0 0.283207 0.279155 0.883333 00:13

Showing results

Blurr provides a @typedispatch-ed implementation of Learner.show_results, which we use here.

learn.show_results(learner=learn, max_n=2, trunc_at=500)
text target prediction
0 Haha, what a great little movie! Wayne Crawford strikes again, or rather this was his first big strike, a deliriously entertaining little ball of manic kitsch energy masquerading as a psycho killer movie. It's actually a **brilliant** satire on post-hippie American culture in flyover country, though the movie was actually filmed independently in Miami. It defies any kind of studio oriented convention or plot device that I can think of: SOMETIMES AUNT MARTHA DOES DREADFUL THINGS may not be a ver pos pos
1 Most would agree that the character of Wolverine is one of the most intriguing characters in comic book history. I'm no Marvel expert, but I did grow up with the adventures of the X-Men and definitely approved of Hugh Jackman's now widely known portrayal of the scruffy Logan. I enjoyed the first X-Men, found the sequel too heavy and messy and liked the third one as comic book entertainment. All through the three movies, I probably enjoyed Jackman more than anything else. I figured the idea of m neg neg
learn.unfreeze()
set_seed()
learn.fit_one_cycle(2, lr_max=slice(1e-7, 1e-4))
epoch train_loss valid_loss accuracy time
0 0.226185 0.279885 0.879167 00:21
1 0.201225 0.233090 0.900000 00:21
learn.recorder.plot_loss()

learn.show_results(learner=learn, max_n=2, trunc_at=500)
text target prediction
0 Haha, what a great little movie! Wayne Crawford strikes again, or rather this was his first big strike, a deliriously entertaining little ball of manic kitsch energy masquerading as a psycho killer movie. It's actually a **brilliant** satire on post-hippie American culture in flyover country, though the movie was actually filmed independently in Miami. It defies any kind of studio oriented convention or plot device that I can think of: SOMETIMES AUNT MARTHA DOES DREADFUL THINGS may not be a ver pos pos
1 Most would agree that the character of Wolverine is one of the most intriguing characters in comic book history. I'm no Marvel expert, but I did grow up with the adventures of the X-Men and definitely approved of Hugh Jackman's now widely known portrayal of the scruffy Logan. I enjoyed the first X-Men, found the sequel too heavy and messy and liked the third one as comic book entertainment. All through the three movies, I probably enjoyed Jackman more than anything else. I figured the idea of m neg neg

Prediction

We need to replace fastai's Learner.predict method with Learner.blurr_predict, which is able to work with inputs that are represented by multiple tensors included in a dictionary.


source

Learner.blurr_predict

 Learner.blurr_predict (items, rm_type_tfms=None)

learn.blurr_predict("I really liked the movie")
[{'label': 'pos',
  'score': 0.9718672633171082,
  'class_index': 1,
  'class_labels': ['neg', 'pos'],
  'probs': [0.028132732957601547, 0.9718672633171082]}]
learn.blurr_predict("Acting was so bad it was almost funny.")
[{'label': 'neg',
  'score': 0.665842592716217,
  'class_index': 0,
  'class_labels': ['neg', 'pos'],
  'probs': [0.665842592716217, 0.33415740728378296]}]
learn.blurr_predict(["I really liked the movie", "I really hated the movie"])
[{'label': 'pos',
  'score': 0.9718672633171082,
  'class_index': 1,
  'class_labels': ['neg', 'pos'],
  'probs': [0.028132745996117592, 0.9718672633171082]},
 {'label': 'pos',
  'score': 0.5788970589637756,
  'class_index': 1,
  'class_labels': ['neg', 'pos'],
  'probs': [0.42110294103622437, 0.5788970589637756]}]

Text generation

Though not useful in sequence classification, we will also add a blurr_generate method to Learner that uses Hugging Face’s PreTrainedModel.generate for text generation tasks.

For the full list of arguments you can pass in see here. You can also check out their “How To Generate” notebook for more information about how it all works.


source

Learner.blurr_generate

 Learner.blurr_generate (items, key='generated_texts', **kwargs)

Uses the built-in generate method to generate the text (see here for a list of arguments you can pass in)
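As an illustration only (for a Learner wrapping a generative model such as GPT-2 or BART rather than the classifier trained above, and assuming the return format mirrors blurr_predict's list of dicts, keyed by the key argument shown in the signature):

# Illustrative sketch: extra keyword arguments are forwarded to PreTrainedModel.generate.
outputs = learn.blurr_generate("The movie was", max_length=50, num_beams=4, early_stopping=True)
print(outputs[0]["generated_texts"])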



Inference

Using fast.ai Learner.export and load_learner

export_fname = "seq_class_learn_export"
learn.export(fname=f"{export_fname}.pkl")
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_predict("This movie should not be seen by anyone!!!!")
[{'label': 'neg',
  'score': 0.8900553584098816,
  'class_index': 0,
  'class_labels': ['neg', 'pos'],
  'probs': [0.8900553584098816, 0.1099446639418602]}]

High-level API

model_cls = AutoModelForSequenceClassification
hf_logging.set_verbosity_error()

pretrained_model_name = "distilroberta-base"  # "distilbert-base-uncased" "bert-base-uncased"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)

dls = dblock.dataloaders(imdb_df, bs=4)

source

Blearner

 Blearner (dls:fastai.data.core.DataLoaders,
           hf_model:transformers.modeling_utils.PreTrainedModel,
           base_model_cb:__main__.BaseModelCallback=<class
           '__main__.BaseModelCallback'>, loss_func:callable|None=None,
           opt_func=<function Adam>, lr=0.001, splitter:callable=<function
           trainable_params>, cbs=None, metrics=None, path=None,
           model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
           moms=(0.95, 0.85, 0.95), default_cbs:bool=True)

Group together a model, some dls and a loss_func to handle training

Argument       Type               Default            Details
dls            DataLoaders                           DataLoaders containing data for each dataset needed for model
hf_model       PreTrainedModel                       Your pretrained Hugging Face transformer
base_model_cb  BaseModelCallback  BaseModelCallback  Your BaseModelCallback
Returns        Learner

Instead of constructing our low-level Learner, we can use the Blearner class, which provides sensible defaults for training.

learn = Blearner(dls, hf_model, metrics=[accuracy])
learn.fit_one_cycle(1, lr_max=1e-3)
epoch train_loss valid_loss accuracy time
0 0.253714 0.305421 0.866667 00:13
learn.show_results(learner=learn, max_n=2, trunc_at=500)
text target prediction
0 Haha, what a great little movie! Wayne Crawford strikes again, or rather this was his first big strike, a deliriously entertaining little ball of manic kitsch energy masquerading as a psycho killer movie. It's actually a **brilliant** satire on post-hippie American culture in flyover country, though the movie was actually filmed independently in Miami. It defies any kind of studio oriented convention or plot device that I can think of: SOMETIMES AUNT MARTHA DOES DREADFUL THINGS may not be a ver pos pos
1 Most would agree that the character of Wolverine is one of the most intriguing characters in comic book history. I'm no Marvel expert, but I did grow up with the adventures of the X-Men and definitely approved of Hugh Jackman's now widely known portrayal of the scruffy Logan. I enjoyed the first X-Men, found the sequel too heavy and messy and liked the third one as comic book entertainment. All through the three movies, I probably enjoyed Jackman more than anything else. I figured the idea of m neg neg
learn.blurr_predict("This was a really good movie")
[{'label': 'pos',
  'score': 0.9749817848205566,
  'class_index': 1,
  'class_labels': ['neg', 'pos'],
  'probs': [0.02501819096505642, 0.9749817848205566]}]
learn.export(fname=f"{export_fname}.pkl")
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_predict("This movie should not be seen by anyone!!!!")
[{'label': 'neg',
  'score': 0.7373340129852295,
  'class_index': 0,
  'class_labels': ['neg', 'pos'],
  'probs': [0.7373340129852295, 0.2626659572124481]}]

source

BlearnerForSequenceClassification

 BlearnerForSequenceClassification (dls:fastai.data.core.DataLoaders,
                                    hf_model:transformers.modeling_utils.P
                                    reTrainedModel, base_model_cb:__main__
                                    .BaseModelCallback=<class
                                    '__main__.BaseModelCallback'>,
                                    loss_func:callable|None=None,
                                    opt_func=<function Adam>, lr=0.001,
                                    splitter:callable=<function
                                    trainable_params>, cbs=None,
                                    metrics=None, path=None,
                                    model_dir='models', wd=None,
                                    wd_bn_bias=False, train_bn=True,
                                    moms=(0.95, 0.85, 0.95),
                                    default_cbs:bool=True)

Group together a model, some dls and a loss_func to handle training

Argument  Type             Details
dls       DataLoaders      DataLoaders containing data for each dataset needed for model
hf_model  PreTrainedModel

We also introduce a classification-task-specific Blearner that gets you your DataBlock, DataLoaders, and Blearner in one line of code!

Examples

Using Mid-level API building blocks

learn = BlearnerForSequenceClassification.from_data(
    imdb_df, "distilroberta-base", text_attr="text", label_attr="label", dl_kwargs={"bs": 4}
)
learn.fit_one_cycle(1, lr_max=1e-3)
epoch train_loss valid_loss f1_score accuracy time
0 0.289666 0.233302 0.920188 0.915000 00:13
learn.show_results(learner=learn, max_n=2, trunc_at=500)
text target prediction
0 Watching Stranger Than Fiction director Marc Forster's The Kite Runner is the cinematic equivalent of eating your vegetables because this art-house epic rated PG-13 is good for your movie-going diet. No, this isn't the kind of movie that I like to slouch on the couch and eyeball at the end of a tough day. The Kite Runner isn't your typical mainstream movie designed to entertain you and make you forget about your troubles. First, no celebrity stars appear in it. Second, nothing is cut and dried, 1 1
1 As an ancient movie fan, I had heard much about the controversial movie CALIGULA assessed ambiguously as one of the most realistic epics by some and as one of the most disgusting porn movies by others. I decided to see it in the entire uncut version to evaluate it myself hoping to find something positive that would make justice to the many accusations towards the film. I sat down in my chair one autumn evening and started to watch. The beginning quotation from the New Testament shocked me a bit 0 0
learn.predict("This was a really good movie")
[{'label': '1',
  'score': 0.9277380108833313,
  'class_index': 1,
  'class_labels': [0, 1],
  'probs': [0.07226195186376572, 0.9277380108833313]}]
learn.export(fname=f"{export_fname}.pkl")
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_predict("This movie should not be seen by anyone!!!!")
[{'label': '0',
  'score': 0.5971986651420593,
  'class_index': 0,
  'class_labels': [0, 1],
  'probs': [0.5971986651420593, 0.4028013050556183]}]

Using Low-level API building blocks

Thanks to the TextDataLoader, there isn't really anything extra you have to do to use plain ol' PyTorch or fast.ai Datasets and DataLoaders with Blurr. Let's take a look at fine-tuning a model against GLUE's MRPC dataset …

Build your Hugging Face objects
model_cls = AutoModelForSequenceClassification

pretrained_model_name = "distilroberta-base"  # "distilbert-base-uncased" "bert-base-uncased"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
Preprocess your data
from datasets import load_dataset
from blurr.text.data.core import preproc_hf_dataset

raw_datasets = load_dataset("glue", "mrpc")
Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /home/runner/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...
Dataset glue downloaded and prepared to /home/runner/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.
def tokenize_function(example):
    return hf_tokenizer(example["sentence1"], example["sentence2"], truncation=True)


tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
Loading cached processed dataset at /home/wgilliam/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2a4493b3e0d7eec3.arrow
Loading cached processed dataset at /home/wgilliam/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-429f84b2ba09bf45.arrow
Loading cached processed dataset at /home/wgilliam/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-38abcb2a8785e400.arrow
Build your DataLoaders
label_names = raw_datasets["train"].features["label"].names

trn_dl = TextDataLoader(
    tokenized_datasets["train"],
    hf_arch=hf_arch,
    hf_config=hf_config,
    hf_tokenizer=hf_tokenizer,
    hf_model=hf_model,
    preproccesing_func=preproc_hf_dataset,
    batch_decode_kwargs={"labels": label_names},
    shuffle=True,
    batch_size=8,
)

val_dl = TextDataLoader(
    tokenized_datasets["validation"],
    hf_arch=hf_arch,
    hf_config=hf_config,
    hf_tokenizer=hf_tokenizer,
    hf_model=hf_model,
    preproccesing_func=preproc_hf_dataset,
    batch_decode_kwargs={"labels": label_names},
    batch_size=16,
)

dls = DataLoaders(trn_dl, val_dl)
Define your Blearner
learn = BlearnerForSequenceClassification(dls, hf_model, loss_func=PreCalculatedCrossEntropyLoss())
Train
learn.lr_find()
SuggestedLRs(valley=7.585775892948732e-05)

learn.fit_one_cycle(1, lr_max=1e-3)
epoch train_loss valid_loss time
0 0.522506 0.483814 00:13
learn.unfreeze()
learn.fit_one_cycle(2, lr_max=slice(1e-8, 1e-6))
epoch train_loss valid_loss time
0 0.522194 0.483406 00:26
1 0.503598 0.482703 00:28
learn.show_results(learner=learn, max_n=2, trunc_at=500)
text target prediction
0 Spansion products are to be available from both AMD and Fujitsu, AMD said. Spansion Flash memory solutions are available worldwide from AMD and Fujitsu. equivalent equivalent
1 However, EPA officials would not confirm the 20 percent figure. Only in the past few weeks have officials settled on the 20 percent figure. not_equivalent not_equivalent

Tests

The tests below ensure the core training code above works for all pretrained sequence classification models available in Hugging Face. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained classification models you are working with … and if any of your pretrained sequence classification models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).
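Since the actual test code isn't rendered on this page, the following is only a rough, hypothetical sketch of what such a loop might look like (the checkpoint names and pass/fail bookkeeping are illustrative, reusing the from_data call shown earlier):

# Hypothetical smoke test: build a Blearner per checkpoint, run one epoch, record the outcome.
results = []
for checkpoint in ["bert-base-uncased", "distilroberta-base"]:  # swap in the models you care about
    try:
        learn = BlearnerForSequenceClassification.from_data(
            imdb_df, checkpoint, text_attr="text", label_attr="label", dl_kwargs={"bs": 4}
        )
        learn.fit_one_cycle(1, lr_max=1e-3)
        results.append((checkpoint, "PASSED", ""))
    except Exception as err:
        results.append((checkpoint, "FAILED", str(err)))

pd.DataFrame(results, columns=["model", "result", "error"])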

arch tokenizer model result error
0 albert AlbertTokenizerFast AlbertForSequenceClassification PASSED
1 bart BartTokenizerFast BartForSequenceClassification PASSED
2 bert BertTokenizerFast BertForSequenceClassification PASSED
3 big_bird BigBirdTokenizerFast BigBirdForSequenceClassification PASSED
4 bigbird_pegasus PegasusTokenizerFast BigBirdPegasusForSequenceClassification PASSED
5 ctrl CTRLTokenizer CTRLForSequenceClassification PASSED
6 camembert CamembertTokenizerFast CamembertForSequenceClassification PASSED
7 canine CanineTokenizer CanineForSequenceClassification PASSED
8 convbert ConvBertTokenizerFast ConvBertForSequenceClassification PASSED
9 deberta DebertaTokenizerFast DebertaForSequenceClassification PASSED
10 deberta_v2 DebertaV2TokenizerFast DebertaV2ForSequenceClassification PASSED
11 distilbert DistilBertTokenizerFast DistilBertForSequenceClassification PASSED
12 electra ElectraTokenizerFast ElectraForSequenceClassification PASSED
13 flaubert FlaubertTokenizer FlaubertForSequenceClassification PASSED
14 funnel FunnelTokenizerFast FunnelForSequenceClassification PASSED
15 gpt2 GPT2TokenizerFast GPT2ForSequenceClassification PASSED
16 gptj GPT2TokenizerFast GPTJForSequenceClassification PASSED
17 gpt_neo GPT2TokenizerFast GPTNeoForSequenceClassification PASSED
18 ibert RobertaTokenizer IBertForSequenceClassification PASSED
19 led LEDTokenizerFast LEDForSequenceClassification PASSED
20 longformer LongformerTokenizerFast LongformerForSequenceClassification PASSED
21 mbart MBartTokenizerFast MBartForSequenceClassification PASSED
22 mpnet MPNetTokenizerFast MPNetForSequenceClassification PASSED
23 mobilebert MobileBertTokenizerFast MobileBertForSequenceClassification PASSED
24 openai OpenAIGPTTokenizerFast OpenAIGPTForSequenceClassification PASSED
25 rembert RemBertTokenizerFast RemBertForSequenceClassification PASSED
26 roformer RoFormerTokenizerFast RoFormerForSequenceClassification PASSED
27 roberta RobertaTokenizerFast RobertaForSequenceClassification PASSED
28 squeezebert SqueezeBertTokenizerFast SqueezeBertForSequenceClassification PASSED
29 transfo_xl TransfoXLTokenizer TransfoXLForSequenceClassification PASSED
30 xlm XLMTokenizer XLMForSequenceClassification PASSED
31 xlm_roberta XLMRobertaTokenizerFast XLMRobertaForSequenceClassification PASSED
32 xlnet XLNetTokenizerFast XLNetForSequenceClassification PASSED