Modeling

The text.modeling.token_classification module contains custom models, loss functions, custom splitters, etc., for token classification tasks such as named entity recognition (NER) and part-of-speech (POS) tagging. The objective of token classification is to predict the correct label for each token in the input. In the computer vision world, this is akin to segmentation tasks, where we attempt to predict the class/label for each pixel in an image.

Setup

We’ll use a subset of conll2003 to demonstrate how to configure your BLURR code for token classification.

Note: Make sure you set the config.num_labels attribute to the number of labels your model is predicting. The model will update its last layer accordingly, à la transfer learning.
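
The code cells in this section assume the usual notebook setup is already in place. A minimal sketch of the imports they rely on follows (a sketch only; exact module paths can differ between BLURR versions):

from functools import partial

import pandas as pd
from datasets import load_dataset
from fastcore.test import test_eq
from transformers import AutoConfig, AutoModelForTokenClassification
from transformers import logging as hf_logging

from fastai.text.all import *           # DataBlock, ColReader, RandomSplitter, Learner, Adam, load_learner, ...
from blurr.text.data.all import *       # TextBlock, TokenCategoryBlock, TokenClassBatchTokenizeTransform, ...
from blurr.text.modeling.all import *   # BaseModelWrapper, TokenClassMetricsCallback, BlearnerForTokenClassification, ...
from blurr.text.utils import get_hf_objects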

raw_datasets = load_dataset("conll2003")

labels = raw_datasets["train"].features["ner_tags"].feature.names
print(f"Labels: {labels}")

conll2003_df = pd.DataFrame(raw_datasets["train"])
conll2003_df.head()
Reusing dataset conll2003 (/home/wgilliam/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)
Labels: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
id tokens pos_tags chunk_tags ner_tags
0 0 [EU, rejects, German, call, to, boycott, British, lamb, .] [22, 42, 16, 21, 35, 37, 16, 21, 7] [11, 21, 11, 12, 21, 22, 11, 12, 0] [3, 0, 7, 0, 0, 0, 7, 0, 0]
1 1 [Peter, Blackburn] [22, 22] [11, 12] [1, 2]
2 2 [BRUSSELS, 1996-08-22] [22, 11] [11, 12] [5, 0]
3 3 [The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, to, consumers, to, shun, British, lamb, until, scientists, determine, whether, mad, cow, disease, can, be, transmitted, to, sheep, .] [12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21, 15, 24, 41, 15, 16, 21, 21, 20, 37, 40, 35, 21, 7] [11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12, 17, 11, 21, 17, 11, 12, 12, 21, 22, 22, 13, 11, 0] [0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
4 4 [Germany, 's, representative, to, the, European, Union, 's, veterinary, committee, Werner, Zwingmann, said, on, Wednesday, consumers, should, buy, sheepmeat, from, countries, other, than, Britain, until, the, scientific, advice, was, clearer, .] [22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 22, 38, 15, 22, 24, 20, 37, 21, 15, 24, 16, 15, 22, 15, 12, 16, 21, 38, 17, 7] [11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21, 22, 11, 13, 11, 1, 13, 11, 17, 11, 12, 12, 21, 1, 0] [5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0]
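
The ner_tags column stores integer class ids; mapping them back through labels makes the annotations easier to read. A quick illustrative check:

# Pair each token of the first example with its string label (illustrative)
example = conll2003_df.iloc[0]
print(list(zip(example["tokens"], [labels[i] for i in example["ner_tags"]])))
# e.g. [('EU', 'B-ORG'), ('rejects', 'O'), ('German', 'B-MISC'), ...]
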
model_cls = AutoModelForTokenClassification
hf_logging.set_verbosity_error()

pretrained_model_name = "roberta-base"
config = AutoConfig.from_pretrained(pretrained_model_name)

config.num_labels = len(labels)
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls, config=config)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('roberta',
 transformers.models.roberta.configuration_roberta.RobertaConfig,
 transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast,
 transformers.models.roberta.modeling_roberta.RobertaForTokenClassification)
test_eq(hf_config.num_labels, len(labels))
batch_tok_tfm = TokenClassBatchTokenizeTransform(hf_arch, hf_config, hf_tokenizer, hf_model)
blocks = (TextBlock(batch_tokenize_tfm=batch_tok_tfm, input_return_type=TokenClassTextInput), TokenCategoryBlock(vocab=labels))

dblock = DataBlock(blocks=blocks, get_x=ColReader("tokens"), get_y=ColReader("ner_tags"), splitter=RandomSplitter())
dls = dblock.dataloaders(conll2003_df, bs=4)
b = dls.one_batch()
dls.show_batch(dataloaders=dls, max_n=2)
word / target label
0 [('15', 'O'), ('-', 'O'), ('Christian', 'B-PER'), ('Cullen', 'I-PER'), (',', 'O'), ('14', 'O'), ('-', 'O'), ('Jeff', 'B-PER'), ('Wilson', 'I-PER'), (',', 'O'), ('13', 'O'), ('-', 'O'), ('Walter', 'B-PER'), ('Little', 'I-PER'), (',', 'O'), ('12', 'O'), ('-', 'O'), ('Frank', 'B-PER'), ('Bunce', 'I-PER'), (',', 'O'), ('11', 'O'), ('-', 'O'), ('Glen', 'B-PER'), ('Osborne', 'I-PER'), (';', 'O'), ('10', 'O'), ('-', 'O'), ('Andrew', 'B-PER'), ('Mehrtens', 'I-PER'), (',', 'O'), ('9', 'O'), ('-', 'O'), ('Justin', 'B-PER'), ('Marshall', 'I-PER'), (';', 'O'), ('8', 'O'), ('-', 'O'), ('Zinzan', 'B-PER'), ('Brooke', 'I-PER'), (',', 'O'), ('7', 'O'), ('-', 'O'), ('Josh', 'B-PER'), ('Kronfeld', 'I-PER'), (',', 'O'), ('6', 'O'), ('-', 'O'), ('Michael', 'B-PER'), ('Jones', 'I-PER'), (',', 'O'), ('5', 'O'), ('-', 'O'), ('Ian', 'B-PER'), ('Jones', 'I-PER'), (',', 'O'), ('4', 'O'), ('-', 'O'), ('Robin', 'B-PER'), ('Brooke', 'I-PER'), (',', 'O'), ('3', 'O'), ('-', 'O'), ('Olo', 'B-PER'), ('Brown', 'I-PER'), (',', 'O'), ('2', 'O'), ('-', 'O'), ('Sean', 'B-PER'), ('Fitzpatrick', 'I-PER'), ('(', 'O'), ('captain', 'O'), (')', 'O'), (',', 'O'), ('1', 'O'), ('-', 'O'), ('Craig', 'B-PER'), ('Dowd', 'I-PER'), ('.', 'O')]
1 [('A', 'O'), ('super', 'O'), ('piece', 'O'), ('of', 'O'), ('fielding', 'O'), ('by', 'O'), ('Lewis', 'B-PER'), (',', 'O'), ('dropped', 'O'), ('as', 'O'), ('a', 'O'), ('disciplinary', 'O'), ('measure', 'O'), ('after', 'O'), ('arriving', 'O'), ('only', 'O'), ('35', 'O'), ('minutes', 'O'), ('before', 'O'), ('the', 'O'), ('start', 'O'), ('on', 'O'), ('the', 'O'), ('fourth', 'O'), ('morning', 'O'), (',', 'O'), ('provided', 'O'), ('the', 'O'), ('only', 'O'), ('bright', 'O'), ('spot', 'O'), ('for', 'O'), ('England', 'B-LOC'), ('as', 'O'), ('the', 'O'), ('touring', 'O'), ('team', 'O'), ('batted', 'O'), ('on', 'O'), ('to', 'O'), ('reach', 'O'), ('413', 'O'), ('for', 'O'), ('five', 'O'), ('at', 'O'), ('the', 'O'), ('interval', 'O'), (',', 'O'), ('a', 'O'), ('lead', 'O'), ('of', 'O'), ('87', 'O'), ('.', 'O')]

Mid-level API

In this section, we’ll add helpful metrics for token classification tasks.


source

calculate_token_class_metrics

 calculate_token_class_metrics (pred_toks, targ_toks, metric_key)

source

TokenClassMetricsCallback

 TokenClassMetricsCallback (tok_metrics=['accuracy', 'precision',
                            'recall', 'f1'], **kwargs)

A fastai-friendly callback that includes accuracy, precision, recall, and f1 metrics using the seqeval library. Additionally, this metric knows not to include your ‘ignore_token’ in its calculations.

See the seqeval documentation for more information.
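
For reference, seqeval scores predictions at the entity level rather than the token level: a prediction only counts if the whole entity span and its type match. A minimal sketch of that scoring (the label lists below are made up):

from seqeval.metrics import classification_report, f1_score

# Entity-level scoring: the second example's ORG span is only partially predicted, so it counts as a miss
y_true = [["B-PER", "I-PER", "O", "B-LOC"], ["B-ORG", "I-ORG", "O"]]
y_pred = [["B-PER", "I-PER", "O", "B-LOC"], ["B-ORG", "O", "O"]]

print(f1_score(y_true, y_pred))
print(classification_report(y_true, y_pred))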

Example

Training

model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
fit_cbs = [TokenClassMetricsCallback()]

learn = Learner(dls, model, opt_func=partial(Adam), loss_func=PreCalculatedCrossEntropyLoss(), cbs=learn_cbs, splitter=blurr_splitter)

learn.freeze()
learn.summary()
b = dls.one_batch()
preds = learn.model(b[0])
len(preds), type(preds), preds.keys()
(2,
 transformers.modeling_outputs.TokenClassifierOutput,
 odict_keys(['loss', 'logits']))
len(b), len(b[0]), b[0]["input_ids"].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([4, 88]), 4, torch.Size([4, 88]))
# b[0]["labels"].shape
preds.logits.shape
torch.Size([4, 88, 9])
print(preds.logits.view(-1, preds.logits.shape[-1]).shape, b[1].view(-1).shape)
test_eq(preds.logits.view(-1, preds.logits.shape[-1]).shape[0], b[1].view(-1).shape[0])
torch.Size([352, 9]) torch.Size([352])
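
Because the flattened logits and targets line up, you could also compute the token-level loss by hand. A quick illustrative check (assuming the ignored/special-token targets use the usual -100 ignore index):

import torch.nn.functional as F

# Illustrative only: PreCalculatedCrossEntropyLoss reuses the loss the Hugging Face model
# already computed (preds.loss); this simply shows the flattened tensors are compatible.
manual_loss = F.cross_entropy(preds.logits.view(-1, preds.logits.shape[-1]), b[1].view(-1), ignore_index=-100)
print(manual_loss, preds.loss)
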
print(len(learn.opt.param_groups))
3
learn.unfreeze()
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=0.0009120108559727668, steep=6.30957365501672e-05, valley=0.00013182566908653826, slide=4.365158383734524e-05)

learn.fit_one_cycle(1, lr_max=3e-5, moms=(0.8, 0.7, 0.8), cbs=fit_cbs)
epoch train_loss valid_loss accuracy precision recall f1 time
0 0.065990 0.049054 0.989001 0.941656 0.930676 0.936134 03:12
print(learn.token_classification_report)
              precision    recall  f1-score   support

         LOC       0.97      0.95      0.96      1439
        MISC       0.89      0.87      0.88       737
         ORG       0.91      0.89      0.90      1241
         PER       0.97      0.98      0.97      1300

   micro avg       0.94      0.93      0.94      4717
   macro avg       0.93      0.92      0.93      4717
weighted avg       0.94      0.93      0.94      4717

Showing results

Below we’ll add additional functionality to show the results of our model more intuitively.

learn.show_results(learner=learn, max_n=2, trunc_at=10)
token / target label / predicted label
0 [('MARKET', 'O', 'O'), ('TALK', 'O', 'O'), ('-', 'O', 'O'), ('USDA', 'B-ORG', 'B-ORG'), ('net', 'O', 'O'), ('change', 'O', 'O'), ('in', 'O', 'O'), ('weekly', 'O', 'O'), ('export', 'O', 'O'), ('commitments', 'O', 'O')]
1 [('Innocent', 'B-PER', 'B-PER'), ('Butare', 'I-PER', 'I-PER'), (',', 'O', 'O'), ('executive', 'O', 'O'), ('secretary', 'O', 'O'), ('of', 'O', 'O'), ('the', 'O', 'O'), ('Rally', 'B-ORG', 'B-ORG'), ('for', 'I-ORG', 'I-ORG'), ('the', 'I-ORG', 'I-ORG')]

Prediction

The default Learner.predict method returns a prediction per subtoken, including the special tokens for each architecture’s tokenizer. Starting with version 2.0 of BLURR, we bring token prediction in line with Hugging Face’s token classification pipeline, both in terms of supporting the same aggregation strategies via BLURR’s TokenAggregationStrategies class, and in terms of the output format via BLURR’s @patched Learner method, blurr_predict_tokens.


source

TokenAggregationStrategies

 TokenAggregationStrategies (hf_tokenizer:transformers.tokenization_utils_
                             base.PreTrainedTokenizerBase,
                             labels:List[str], non_entity_label:str='O')

Provides the equivalent of Hugging Face’s token classification pipeline’s aggregation_strategy support across various token classification tasks (e.g., NER, POS, chunking, etc.).
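
For comparison, the same aggregation_strategy options exist on Hugging Face’s own token classification pipeline. A quick sketch using the hf_model and hf_tokenizer created above (the pipeline call is illustrative and not part of BLURR):

from transformers import pipeline

# Hugging Face's pipeline exposes the aggregation strategies that TokenAggregationStrategies mirrors
hf_nlp = pipeline("token-classification", model=hf_model, tokenizer=hf_tokenizer, aggregation_strategy="max")
print(hf_nlp("Bayern Munich is a soccer team in Germany"))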


source

Learner.blurr_predict_tokens

 Learner.blurr_predict_tokens (items:Union[str,List[str]],
                               aggregation_strategy:str='simple',
                               non_entity_label:str='O',
                               slow_word_ids_func:Optional[Callable]=None)
Type Default Details
items Union  The str (or list of strings) you want to get token classification predictions for
aggregation_strategy str simple How entities are grouped and scored
non_entity_label str O The label used to identify non-entity related words/tokens
slow_word_ids_func Optional None If using a slow tokenizer, users will need to provide a slow_word_ids_func that accepts a tokenizer, example index, and a batch encoding as arguments and in turn returns the equivalent of a fast tokenizer’s word_ids


res = learn.blurr_predict_tokens(
    items=["My name is Wayde and I live in San Diego and using Hugging Face", "Bayern Munich is a soccer team in Germany"],
    aggregation_strategy="max",
)

print(len(res))
print(res[1])
2
[{'entity_group': 'ORG', 'score': 0.9952805638313293, 'word': 'Bayern Munich', 'start': 0, 'end': 13}, {'entity_group': 'LOC', 'score': 0.9980798959732056, 'word': 'Germany', 'start': 34, 'end': 41}]
txt = "Hi! My name is Wayde Gilliam from ohmeow.com. I live in California."
txt2 = "I wish covid was over so I could go to Germany and watch Bayern Munich play in the Bundesliga."
res = learn.blurr_predict_tokens(txt)
print(res)
[[{'entity_group': 'PER', 'score': 0.8786835372447968, 'word': 'Wayde Gilliam', 'start': 15, 'end': 28}, {'entity_group': 'PER', 'score': 0.30589407682418823, 'word': 'oh', 'start': 34, 'end': 36}, {'entity_group': 'ORG', 'score': 0.31899651885032654, 'word': 'meow', 'start': 36, 'end': 40}, {'entity_group': 'LOC', 'score': 0.1332944929599762, 'word': '.', 'start': 44, 'end': 45}, {'entity_group': 'LOC', 'score': 0.9964601397514343, 'word': 'California', 'start': 56, 'end': 66}, {'entity_group': 'LOC', 'score': 0.13329452276229858, 'word': '.', 'start': 66, 'end': 67}]]
results = learn.blurr_predict_tokens([txt, txt2])
for res in results:
    print(f"{res}\n")
[{'entity_group': 'PER', 'score': 0.8786835372447968, 'word': 'Wayde Gilliam', 'start': 15, 'end': 28}, {'entity_group': 'PER', 'score': 0.30589407682418823, 'word': 'oh', 'start': 34, 'end': 36}, {'entity_group': 'ORG', 'score': 0.31899651885032654, 'word': 'meow', 'start': 36, 'end': 40}, {'entity_group': 'LOC', 'score': 0.1332944929599762, 'word': '.', 'start': 44, 'end': 45}, {'entity_group': 'LOC', 'score': 0.9964601397514343, 'word': 'California', 'start': 56, 'end': 66}, {'entity_group': 'LOC', 'score': 0.13329452276229858, 'word': '.', 'start': 66, 'end': 67}]

[{'entity_group': 'LOC', 'score': 0.9927727580070496, 'word': 'Germany', 'start': 39, 'end': 46}, {'entity_group': 'ORG', 'score': 0.9933880269527435, 'word': 'Bayern Munich', 'start': 57, 'end': 70}, {'entity_group': 'MISC', 'score': 0.9065296053886414, 'word': 'Bundesliga', 'start': 83, 'end': 93}, {'entity_group': 'ORG', 'score': 0.12906739115715027, 'word': '.', 'start': 93, 'end': 94}]
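
Since the strategies mirror Hugging Face’s pipeline, you can pass any of its other options (e.g. "first", "average") to change how sub-token scores are combined; for example:

# Illustrative: average sub-token scores within each predicted entity group
print(learn.blurr_predict_tokens(txt2, aggregation_strategy="average"))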

Inference

export_fname = "tok_class_learn_export"
learn.export(fname=f"{export_fname}.pkl")
inf_learn = load_learner(fname=f"{export_fname}.pkl")

results = inf_learn.blurr_predict_tokens([txt, txt2])
for res in results:
    print(f"{res}\n")
[{'entity_group': 'PER', 'score': 0.8786836713552475, 'word': 'Wayde Gilliam', 'start': 15, 'end': 28}, {'entity_group': 'PER', 'score': 0.30589422583580017, 'word': 'oh', 'start': 34, 'end': 36}, {'entity_group': 'ORG', 'score': 0.31899629533290863, 'word': 'meow', 'start': 36, 'end': 40}, {'entity_group': 'LOC', 'score': 0.1332944929599762, 'word': '.', 'start': 44, 'end': 45}, {'entity_group': 'LOC', 'score': 0.9964601397514343, 'word': 'California', 'start': 56, 'end': 66}, {'entity_group': 'LOC', 'score': 0.13329453766345978, 'word': '.', 'start': 66, 'end': 67}]

[{'entity_group': 'LOC', 'score': 0.9927727580070496, 'word': 'Germany', 'start': 39, 'end': 46}, {'entity_group': 'ORG', 'score': 0.9933880269527435, 'word': 'Bayern Munich', 'start': 57, 'end': 70}, {'entity_group': 'MISC', 'score': 0.9065297245979309, 'word': 'Bundesliga', 'start': 83, 'end': 93}, {'entity_group': 'ORG', 'score': 0.12906737625598907, 'word': '.', 'start': 93, 'end': 94}]

High-level API


source

BlearnerForTokenClassification

 BlearnerForTokenClassification (dls:fastai.data.core.DataLoaders,
                                 hf_model:transformers.modeling_utils.PreTrainedModel,
                                 base_model_cb:blurr.text.modeling.core.BaseModelCallback=<class 'blurr.text.modeling.core.BaseModelCallback'>,
                                 loss_func:callable|None=None,
                                 opt_func=<function Adam>, lr=0.001,
                                 splitter:callable=<function trainable_params>,
                                 cbs=None, metrics=None, path=None,
                                 model_dir='models', wd=None,
                                 wd_bn_bias=False, train_bn=True,
                                 moms=(0.95, 0.85, 0.95),
                                 default_cbs:bool=True)

Group together a model, some dls and a loss_func to handle training

Type Details
dls DataLoaders containing data for each dataset needed for model
hf_model PreTrainedModel

Example

Define your Blearner

hf_logging.set_verbosity_error()

learn = BlearnerForTokenClassification.from_data(
    conll2003_df,
    "distilroberta-base",
    tokens_attr="tokens",
    token_labels_attr="ner_tags",
    labels=labels,
    dl_kwargs={"bs": 2},
)

learn.unfreeze()
learn.dls.show_batch(dataloaders=learn.dls, max_n=2)
word / target label
0 [('MARKET', 'O'), ('TALK', 'O'), ('-', 'O'), ('USDA', 'B-ORG'), ('net', 'O'), ('change', 'O'), ('in', 'O'), ('weekly', 'O'), ('export', 'O'), ('commitments', 'O'), ('for', 'O'), ('the', 'O'), ('week', 'O'), ('ended', 'O'), ('August', 'O'), ('22', 'O'), (',', 'O'), ('includes', 'O'), ('old', 'O'), ('crop', 'O'), ('and', 'O'), ('new', 'O'), ('crop', 'O'), (',', 'O'), ('were', 'O'), (':', 'O'), ('wheat', 'O'), ('up', 'O'), ('595,400', 'O'), ('tonnes', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('corn', 'O'), ('up', 'O'), ('1,900', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('319,600', 'O'), ('new', 'O'), (';', 'O'), ('soybeans', 'O'), ('down', 'O'), ('12,300', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('300,800', 'O'), ('new', 'O'), (';', 'O'), ('upland', 'O'), ('cotton', 'O'), ('up', 'O'), ('50,400', 'O'), ('bales', 'O'), ('new', 'O'), (',', 'O'), ('nil', 'O'), ('old', 'O'), (';', 'O'), ('soymeal', 'O'), ('54,800', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('100,600', 'O'), ('new', 'O'), (',', 'O'), ('soyoil', 'O'), ('nil', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('75,000', 'O'), ('new', 'O'), (';', 'O'), ('barley', 'O'), ('up', 'O'), ('1,700', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('sorghum', 'O'), ('6,200', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('156,700', 'O'), ('new', 'O'), (';', 'O'), ('pima', 'O'), ('cotton', 'O'), ('up', 'O'), ('4,000', 'O'), ('bales', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('rice', 'O'), ('up', 'O'), ('49,900', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), ('...', 'O')]
1 [('The', 'O'), ('Pirates', 'B-ORG'), (',', 'O'), ('who', 'O'), ('conceded', 'O'), ('earlier', 'O'), ('this', 'O'), ('week', 'O'), ('they', 'O'), ('would', 'O'), ('be', 'O'), ('forced', 'O'), ('to', 'O'), ('trim', 'O'), ('salary', 'O'), ('from', 'O'), ('next', 'O'), ('season', 'O'), ("'s", 'O'), ('payroll', 'O'), (',', 'O'), ('received', 'O'), ('Ron', 'B-PER'), ('Wright', 'I-PER'), (',', 'O'), ('a', 'O'), ('first', 'O'), ('baseman', 'O'), ('at', 'O'), ('Double-A', 'O'), ('Greenville', 'B-ORG'), (';', 'O'), ('Corey', 'B-PER'), ('Pointer', 'I-PER'), (',', 'O'), ('a', 'O'), ('pitcher', 'O'), ('at', 'O'), ('Class-A', 'O'), ('Eugene', 'B-ORG'), (',', 'O'), ('and', 'O'), ('a', 'O'), ('player', 'O'), ('to', 'O'), ('be', 'O'), ('named', 'O'), ('.', 'O')]

Train

learn.fit_one_cycle(1, lr_max=3e-5, moms=(0.8, 0.7, 0.8), cbs=[BlearnerForTokenClassification.get_metrics_cb()])
epoch train_loss valid_loss accuracy precision recall f1 time
0 0.066553 0.050842 0.988192 0.934066 0.930526 0.932293 03:59
learn.show_results(learner=learn, max_n=2, trunc_at=10)
token / target label / predicted label
0 [('Innocent', 'B-PER', 'B-PER'), ('Butare', 'I-PER', 'I-PER'), (',', 'O', 'O'), ('executive', 'O', 'O'), ('secretary', 'O', 'O'), ('of', 'O', 'O'), ('the', 'O', 'O'), ('Rally', 'B-ORG', 'B-ORG'), ('for', 'I-ORG', 'I-ORG'), ('the', 'I-ORG', 'I-ORG')]
1 [('"', 'O', 'O'), ('I', 'O', 'O'), ('do', 'O', 'O'), ("n't", 'O', 'O'), ('know', 'O', 'O'), ('what', 'O', 'O'), ('the', 'O', 'O'), ('source', 'O', 'O'), ('of', 'O', 'O'), ('the', 'O', 'O')]
print(learn.token_classification_report)
              precision    recall  f1-score   support

         LOC       0.95      0.96      0.95      1401
        MISC       0.87      0.89      0.88       676
         ORG       0.91      0.89      0.90      1372
         PER       0.98      0.97      0.97      1301

   micro avg       0.93      0.93      0.93      4750
   macro avg       0.93      0.93      0.93      4750
weighted avg       0.93      0.93      0.93      4750

Prediction

txt = "Hi! My name is Wayde Gilliam from ohmeow.com. I live in California."
txt2 = "I wish covid was over so I could watch Lewandowski score some more goals for Bayern Munich in the Bundesliga."
results = learn.predict([txt, txt2])
for res in results:
    print(f"{res}\n")
[{'entity_group': 'PER', 'score': 0.9960938096046448, 'word': 'Way', 'start': 15, 'end': 18}, {'entity_group': 'PER', 'score': 0.9659706950187683, 'word': 'de Gilliam', 'start': 18, 'end': 28}, {'entity_group': 'ORG', 'score': 0.4243549704551697, 'word': 'ohmeow', 'start': 34, 'end': 40}, {'entity_group': 'LOC', 'score': 0.9980984330177307, 'word': 'California', 'start': 56, 'end': 66}]

[{'entity_group': 'PER', 'score': 0.9737088978290558, 'word': 'Lewandowski', 'start': 39, 'end': 50}, {'entity_group': 'ORG', 'score': 0.983256071805954, 'word': 'Bayern Munich', 'start': 77, 'end': 90}, {'entity_group': 'MISC', 'score': 0.957284688949585, 'word': 'Bundesliga', 'start': 98, 'end': 108}]

Tests

The tests below ensure that the token classification training code above works for all pretrained token classification models available in Hugging Face. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained token classification models you are working with, and if any of them fail, please submit a GitHub issue (or a PR if you’d like to fix it yourself).

raw_datasets = load_dataset("conll2003")
labels = raw_datasets["train"].features["ner_tags"].feature.names
conll2003_df = pd.DataFrame(raw_datasets["train"])
Reusing dataset conll2003 (/home/wgilliam/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)
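
The test harness itself isn’t shown here, but conceptually it loops over a list of pretrained checkpoints, builds a BlearnerForTokenClassification for each, runs a short fit, and records the outcome. A minimal sketch (the checkpoint list and single-epoch fit below are illustrative, not the actual test code):

# Illustrative smoke-test loop; the checkpoint names below are examples only
pretrained_model_names = ["bert-base-cased", "roberta-base", "distilroberta-base"]

test_results = []
for model_name in pretrained_model_names:
    error = ""
    try:
        learn = BlearnerForTokenClassification.from_data(
            conll2003_df.head(1000),  # a small subset keeps each run short
            model_name,
            tokens_attr="tokens",
            token_labels_attr="ner_tags",
            labels=labels,
            dl_kwargs={"bs": 2},
        )
        learn.fit_one_cycle(1, lr_max=3e-5)
        result = "PASSED"
    except Exception as e:
        result, error = "FAILED", str(e)

    test_results.append({"model_name": model_name, "result": result, "error": error})

pd.DataFrame(test_results)
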
arch tokenizer model_name result error
0 albert AlbertTokenizerFast AlbertForTokenClassification PASSED
1 bert BertTokenizerFast BertForTokenClassification PASSED
2 big_bird BigBirdTokenizerFast BigBirdForTokenClassification PASSED
3 camembert CamembertTokenizerFast CamembertForTokenClassification PASSED
4 convbert ConvBertTokenizerFast ConvBertForTokenClassification PASSED
5 deberta DebertaTokenizerFast DebertaForTokenClassification PASSED
6 bert BertTokenizerFast BertForTokenClassification PASSED
7 electra ElectraTokenizerFast ElectraForTokenClassification PASSED
8 funnel FunnelTokenizerFast FunnelForTokenClassification PASSED
9 gpt2 GPT2TokenizerFast GPT2ForTokenClassification PASSED
10 layoutlm LayoutLMTokenizerFast LayoutLMForTokenClassification PASSED
11 longformer LongformerTokenizerFast LongformerForTokenClassification PASSED
12 mpnet MPNetTokenizerFast MPNetForTokenClassification PASSED
13 ibert RobertaTokenizerFast IBertForTokenClassification PASSED
14 mobilebert MobileBertTokenizerFast MobileBertForTokenClassification PASSED
15 rembert RemBertTokenizerFast RemBertForTokenClassification PASSED
16 roformer RoFormerTokenizerFast RoFormerForTokenClassification PASSED
17 roberta RobertaTokenizerFast RobertaForTokenClassification PASSED
18 squeezebert SqueezeBertTokenizerFast SqueezeBertForTokenClassification PASSED
19 xlm_roberta XLMRobertaTokenizerFast XLMRobertaForTokenClassification PASSED
20 xlnet XLNetTokenizerFast XLNetForTokenClassification PASSED