## Getting Started
Named after the fastest transformer (well, at least of the Autobots), BLURR provides both a comprehensive and extensible framework for training and deploying 🤗 huggingface transformer models with fastai >= 2.0.
Utilizing features like fastai’s new `@typedispatch` and `@patch` decorators, along with a simple class hierarchy, BLURR provides fastai developers with the ability to train and deploy transformers on a variety of tasks. It includes a high, mid, and low-level API that will allow developers to use much of it out-of-the-box or customize it as needed.
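To give a flavor of those two decorators, here is an illustrative sketch using fastcore directly; the class and functions are made up for demonstration and are not BLURR's actual internals:

```python
from fastcore.basics import patch
from fastcore.dispatch import typedispatch

class Greeter: pass

@patch
def greet(self: Greeter, name):
    "Monkey-patch a new method onto an existing class."
    return f"hello, {name}"

@typedispatch
def show_item(x: int): return f"int: {x}"

@typedispatch
def show_item(x: str): return f"str: {x}"

print(Greeter().greet("blurr"))      # -> hello, blurr
print(show_item(1), show_item("a"))  # dispatches on the argument's type
```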
Supported Text/NLP Tasks:

- Sequence Classification
- Token Classification
- Question Answering
- Summarization
- Translation
- Language Modeling (Causal and Masked)

Supported Vision Tasks:

- In progress

Supported Audio Tasks:

- In progress
## Install
You can now pip install blurr via `pip install ohmeow-blurr`.
Or, better yet, since this library is under very active development, create an editable install like this:
```bash
git clone https://github.com/ohmeow/blurr.git
cd blurr
pip install -e ".[dev]"
```
## How to use
Please check the documentation for more thorough examples of how to use this package.
The following two packages need to be installed for blurr to work:
1. fastai
2. Hugging Face transformers
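If you don't already have them, both are a pip install away (shown unpinned here; check each project's docs for compatible versions):

```bash
pip install fastai transformers
```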
### Imports
"ignore")
warnings.simplefilter(
hf_logging.set_verbosity_error()
"TOKENIZERS_PARALLELISM"] = "false" os.environ[
### Get your data
```python
path = untar_data(URLs.IMDB_SAMPLE)

model_path = Path("models")
imdb_df = pd.read_csv(path / "texts.csv")
```
```python
# get n_labels from the data for the config later
n_labels = len(imdb_df["label"].unique())
```
### Get your 🤗 objects
```python
model_cls = AutoModelForSequenceClassification

pretrained_model_name = "bert-base-uncased"

config = AutoConfig.from_pretrained(pretrained_model_name)
config.num_labels = n_labels

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
    pretrained_model_name,
    model_cls=model_cls,
    config=config
)
```
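`get_hf_objects` hands back the architecture name plus the Hugging Face config, tokenizer, and model instances, so a quick sanity check can confirm what was loaded. A minimal sketch; the values in the comments are what you'd expect for `bert-base-uncased`:

```python
print(hf_arch)                      # expected: "bert"
print(type(hf_tokenizer).__name__)  # expected: a fast BERT tokenizer class
print(type(hf_model).__name__)      # expected: "BertForSequenceClassification"
print(hf_config.num_labels)         # 2, from n_labels above
```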
### Build your Data 🧱 and your DataLoaders
```python
# single input
blocks = (
    TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model),
    CategoryBlock
)

dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader("text"),
    get_y=ColReader("label"),
    splitter=ColSplitter()
)

dls = dblock.dataloaders(imdb_df, bs=4)
```
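Before looking at decoded samples, you can peek at the raw batch. A minimal sketch, assuming BLURR's inputs come back as a dict of tensors keyed by the tokenizer's standard names (keys and sequence length vary by architecture and settings):

```python
xb, yb = dls.one_batch()
print(list(xb.keys()))        # e.g. ['input_ids', 'attention_mask', ...]
print(xb["input_ids"].shape)  # (batch_size, seq_len)
print(yb.shape)               # one label per sample
```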
```python
dls.show_batch(dataloaders=dls, max_n=2, trunc_at=250)
```
| | text | target |
|---|---|---|
| 0 | raising victor vargas : a review < br / > < br / > you know, raising victor vargas is like sticking your hands into a big, steaming bowl of oatmeal. it's warm and gooey, but you're not sure if it feels right. try as i might, no matter how warm and go | negative |
| 1 | the shop around the corner is one of the sweetest and most feel - good romantic comedies ever made. there's just no getting around that, and it's hard to actually put one's feeling for this film into words. it's not one of those films that tries too | positive |
### … and 🚂
```python
# wrap the Hugging Face model so fastai can drive it, then pass the wrapper to Learner
model = BaseModelWrapper(hf_model)

learn = Learner(
    dls,
    model,
    opt_func=partial(Adam, decouple_wd=True),
    loss_func=CrossEntropyLossFlat(),
    metrics=[accuracy],
    cbs=[BaseModelCallback],
    splitter=blurr_splitter,
)

learn.freeze()
```

```python
learn.fit_one_cycle(3, lr_max=1e-3)
```
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.628744 | 0.453862 | 0.780000 | 00:21 |
| 1 | 0.367063 | 0.294906 | 0.895000 | 00:22 |
| 2 | 0.238181 | 0.279067 | 0.900000 | 00:22 |
```python
learn.show_results(learner=learn, max_n=2, trunc_at=250)
```
| | text | target | prediction |
|---|---|---|---|
| 0 | the trouble with the book, " memoirs of a geisha " is that it had japanese surfaces but underneath the surfaces it was all an american man's way of thinking. reading the book is like watching a magnificent ballet with great music, sets, and costumes | negative | negative |
| 1 | < br / > < br / > i'm sure things didn't exactly go the same way in the real life of homer hickam as they did in the film adaptation of his book, rocket boys, but the movie " october sky " ( an anagram of the book's title ) is good enough to stand al | positive | positive |
## Using the high-level Blurr API
Using the high-level API we can reduce DataBlock, DataLoaders, and Learner creation into a single line of code.
Included in the high-level API is a general `Blearner` class (pronounced “Blurrner”) that you can use with hand-crafted DataLoaders, as well as task-specific BLearners like `BlearnerForSequenceClassification` that will handle everything given your raw data sourced from a pandas DataFrame, CSV file, or list of dictionaries (for example, a Hugging Face datasets dataset).
```python
learn = BlearnerForSequenceClassification.from_data(
    imdb_df,
    pretrained_model_name,
    dl_kwargs={"bs": 4}
)
```
```python
learn.fit_one_cycle(1, lr_max=1e-3)
```
| epoch | train_loss | valid_loss | f1_score | accuracy | time |
|---|---|---|---|---|---|
| 0 | 0.530218 | 0.484683 | 0.789189 | 0.805000 | 00:22 |
```python
learn.show_results(learner=learn, max_n=2, trunc_at=250)
```
| | text | target | prediction |
|---|---|---|---|
| 0 | the trouble with the book, " memoirs of a geisha " is that it had japanese surfaces but underneath the surfaces it was all an american man's way of thinking. reading the book is like watching a magnificent ballet with great music, sets, and costumes | negative | negative |
| 1 | < br / > < br / > i'm sure things didn't exactly go the same way in the real life of homer hickam as they did in the film adaptation of his book, rocket boys, but the movie " october sky " ( an anagram of the book's title ) is good enough to stand al | positive | positive |
## ⭐ Props
A word of gratitude to the following individuals, repos, and articles that inspired much of this work:
- The wonderful community that is the fastai forum and especially the tireless work of both Jeremy and Sylvain in building this amazing framework and place to learn deep learning.
- All the great tokenizers, transformers, docs, examples, and people over at huggingface
- FastHugs
- Fastai with 🤗Transformers (BERT, RoBERTa, XLNet, XLM, DistilBERT)
- Fastai integration with BERT: Multi-label text classification identifying toxicity in texts
- fastinference