Data

The text.data.token_classification module contains the pieces required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for token classification tasks such as named entity recognition (NER) and part-of-speech (POS) tagging.

Setup

We’ll use a subset of conll2003 to demonstrate how to configure your blurr code for token classification.

from datasets import load_dataset

raw_datasets = load_dataset("conll2003")
raw_datasets
Reusing dataset conll2003 (/home/wgilliam/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)
DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

We need a list of the distinct entities we want to predict. If the labels are stored in their raw/readable form as a list in another attribute/column of the dataset, a sorted list of distinct values can be built with something like this: labels = sorted(list(set([lbls for sublist in germ_eval_df.labels.tolist() for lbls in sublist]))).
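As a minimal sketch of that idea, the snippet below builds the distinct label list from a few hand-written examples. The `rows` data and the `distinct_labels` helper are hypothetical and only stand in for a real label column (e.g., the result of calling `.tolist()` on a DataFrame column).

```python
# Hypothetical examples: one readable label per word in each example.
rows = [
    ["B-ORG", "O", "B-MISC", "O"],
    ["B-PER", "I-PER"],
    ["B-LOC", "O"],
]


def distinct_labels(label_lists):
    """Return the sorted, de-duplicated labels found across all examples."""
    return sorted({lbl for sublist in label_lists for lbl in sublist})


labels_from_rows = distinct_labels(rows)
print(labels_from_rows)
```

Note that plain sorting places 'O' last. When your labels are already stored as IDs, the order of the label list has to match those IDs, so prefer the order supplied by the dataset in that case.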

Fortunately, the conll2003 dataset allows us to get at this list directly using the code below.

print(raw_datasets["train"].features["chunk_tags"].feature.names[:20])
print(raw_datasets["train"].features["ner_tags"].feature.names[:20])
print(raw_datasets["train"].features["pos_tags"].feature.names[:20])
['O', 'B-ADJP', 'I-ADJP', 'B-ADVP', 'I-ADVP', 'B-CONJP', 'I-CONJP', 'B-INTJ', 'I-INTJ', 'B-LST', 'I-LST', 'B-NP', 'I-NP', 'B-PP', 'I-PP', 'B-PRT', 'I-PRT', 'B-SBAR', 'I-SBAR', 'B-UCP']
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
['"', "''", '#', '$', '(', ')', ',', '.', ':', '``', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS']
labels = raw_datasets["train"].features["ner_tags"].feature.names
labels
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
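import pandas as pd
from transformers import AutoModelForTokenClassification
from transformers.utils import logging as hf_logging

conll2003_df = pd.DataFrame(raw_datasets["train"])
model_cls = AutoModelForTokenClassification
hf_logging.set_verbosity_error()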

pretrained_model_name = "roberta-base"  # "bert-base-multilingual-cased"
n_labels = len(labels)

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
    pretrained_model_name, model_cls=model_cls, config_kwargs={"num_labels": n_labels}
)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('roberta',
 transformers.models.roberta.configuration_roberta.RobertaConfig,
 transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast,
 transformers.models.roberta.modeling_roberta.RobertaForTokenClassification)

Preprocessing

Starting with version 2.0, BLURR provides a token classification preprocessing class that can be used to preprocess DataFrames or Hugging Face Datasets. We also introduce a novel way of handling long documents for this task that ensures the tokens associated with a word are never split across “chunked” documents. See below for an example.


TokenClassPreprocessor

 TokenClassPreprocessor (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                         chunk_examples:bool=False, word_stride:int=2,
                         ignore_token_id:int=-100,
                         label_names:Optional[List[str]]=None,
                         batch_size:int=1000, id_attr:Optional[str]=None,
                         word_list_attr:str='tokens',
                         label_list_attr:str='labels',
                         is_valid_attr:Optional[str]='is_valid',
                         slow_word_ids_func:Optional[Callable]=None,
                         tok_kwargs:dict={})

| | Type | Default | Details |
|---|---|---|---|
| hf_tokenizer | PreTrainedTokenizerBase | | A Hugging Face tokenizer |
| chunk_examples | bool | False | Set to True if the preprocessor should chunk examples that exceed max_length |
| word_stride | int | 2 | Like “stride” except for words (not tokens) |
| ignore_token_id | int | -100 | The token ID that should be ignored when calculating the loss |
| label_names | Optional | None | The label names (if not specified, will build from DataFrame) |
| batch_size | int | 1000 | The number of examples to process at a time |
| id_attr | Optional | None | The unique identifier in the dataset |
| word_list_attr | str | tokens | The attribute holding the list of words |
| label_list_attr | str | labels | The attribute holding the list of labels (one for each word in word_list_attr) |
| is_valid_attr | Optional | is_valid | The attribute that should be created if you are processing individual training and validation datasets into a single dataset; it indicates which of the two each example belongs to |
| slow_word_ids_func | Optional | None | If using a slow tokenizer, users will need to provide a slow_word_ids_func that accepts a tokenizer, example index, and a batch encoding as arguments and in turn returns the equivalent of a fast tokenizer’s word_ids |
| tok_kwargs | dict | {} | Tokenization kwargs that will be applied when calling the tokenizer |

labels are Ids

preprocessor = TokenClassPreprocessor(
    hf_tokenizer,
    chunk_examples=True,
    word_stride=2,
    label_names=labels,
    id_attr="id",
    word_list_attr="tokens",
    label_list_attr="ner_tags",
    tok_kwargs={"max_length": 8},
)
proc_df = preprocessor.process_df(conll2003_df)

print(len(proc_df))
print(preprocessor.label_names)
proc_df.head(4)
61298
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
proc_tokens proc_ner_tags id tokens pos_tags chunk_tags ner_tags
0 [EU, rejects, German, call, to, boycott] [3, 0, 7, 0, 0, 0] 0 [EU, rejects, German, call, to, boycott, British, lamb, .] [22, 42, 16, 21, 35, 37, 16, 21, 7] [11, 21, 11, 12, 21, 22, 11, 12, 0] [3, 0, 7, 0, 0, 0, 7, 0, 0]
1 [to, boycott, British, lamb, .] [0, 0, 7, 0, 0] 0 [EU, rejects, German, call, to, boycott, British, lamb, .] [22, 42, 16, 21, 35, 37, 16, 21, 7] [11, 21, 11, 12, 21, 22, 11, 12, 0] [3, 0, 7, 0, 0, 0, 7, 0, 0]
2 [Peter, Blackburn] [1, 2] 1 [Peter, Blackburn] [22, 22] [11, 12] [1, 2]
3 [BRUSSELS, 1996-08-22] [5, 0] 2 [BRUSSELS, 1996-08-22] [22, 11] [11, 12] [5, 0]
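
To make the chunked rows above easier to follow, here is a rough sketch of word-level chunking with a word stride. It is not BLURR's implementation: the `chunk_words` helper, the `n_special=2` allowance for special tokens, and the one-token-per-word counts are all assumptions chosen to be consistent with the first example shown above (max_length of 8, word_stride of 2).

```python
def chunk_words(words, tokens_per_word, max_length, n_special=2, word_stride=2):
    """Split `words` into chunks that fit in `max_length` tokens without splitting a word.

    Each new chunk starts `word_stride` words before the end of the previous one.
    """
    budget = max_length - n_special  # room left once special tokens are counted (assumed)
    chunks, start = [], 0
    while start < len(words):
        end, used = start, 0
        # Add whole words while their tokens still fit in the budget.
        while end < len(words) and used + tokens_per_word[end] <= budget:
            used += tokens_per_word[end]
            end += 1
        if end == start:  # a single word exceeds the budget; keep it alone to make progress
            end = start + 1
        chunks.append(words[start:end])
        if end >= len(words):
            break
        # Step back by the word stride, but always move forward by at least one word.
        start = max(end - word_stride, start + 1)
    return chunks


words = ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."]
chunks = chunk_words(words, tokens_per_word=[1] * len(words), max_length=8)
for chunk in chunks:
    print(chunk)
```

Under these assumptions the sketch yields the same two chunks as rows 0 and 1 of the output above, with "to" and "boycott" repeated because of the two-word stride.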

labels are entity names

conll2003_labeled_df = conll2003_df.copy()
conll2003_labeled_df.ner_tags = conll2003_labeled_df.ner_tags.apply(lambda v: [labels[lbl_id] for lbl_id in v])
conll2003_labeled_df.head(5)
id tokens pos_tags chunk_tags ner_tags
0 0 [EU, rejects, German, call, to, boycott, British, lamb, .] [22, 42, 16, 21, 35, 37, 16, 21, 7] [11, 21, 11, 12, 21, 22, 11, 12, 0] [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O]
1 1 [Peter, Blackburn] [22, 22] [11, 12] [B-PER, I-PER]
2 2 [BRUSSELS, 1996-08-22] [22, 11] [11, 12] [B-LOC, O]
3 3 [The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, to, consumers, to, shun, British, lamb, until, scientists, determine, whether, mad, cow, disease, can, be, transmitted, to, sheep, .] [12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21, 15, 24, 41, 15, 16, 21, 21, 20, 37, 40, 35, 21, 7] [11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12, 17, 11, 21, 17, 11, 12, 12, 21, 22, 22, 13, 11, 0] [O, B-ORG, I-ORG, O, O, O, O, O, O, B-MISC, O, O, O, O, O, B-MISC, O, O, O, O, O, O, O, O, O, O, O, O, O, O]
4 4 [Germany, 's, representative, to, the, European, Union, 's, veterinary, committee, Werner, Zwingmann, said, on, Wednesday, consumers, should, buy, sheepmeat, from, countries, other, than, Britain, until, the, scientific, advice, was, clearer, .] [22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 22, 38, 15, 22, 24, 20, 37, 21, 15, 24, 16, 15, 22, 15, 12, 16, 21, 38, 17, 7] [11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21, 22, 11, 13, 11, 1, 13, 11, 17, 11, 12, 12, 21, 1, 0] [B-LOC, O, O, O, O, B-ORG, I-ORG, O, O, O, B-PER, I-PER, O, O, O, O, O, O, O, O, O, O, O, B-LOC, O, O, O, O, O, O, O]
preprocessor = TokenClassPreprocessor(
    hf_tokenizer, label_names=labels, id_attr="id", word_list_attr="tokens", label_list_attr="ner_tags", tok_kwargs={"max_length": 8}
)
proc_df = preprocessor.process_df(conll2003_labeled_df)

print(len(proc_df))
print(preprocessor.label_names)
proc_df.head(4)
14041
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
proc_tokens proc_ner_tags id tokens pos_tags chunk_tags ner_tags
0 [EU, rejects, German, call, to, boycott] [B-ORG, O, B-MISC, O, O, O] 0 [EU, rejects, German, call, to, boycott, British, lamb, .] [22, 42, 16, 21, 35, 37, 16, 21, 7] [11, 21, 11, 12, 21, 22, 11, 12, 0] [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O]
1 [Peter, Blackburn] [B-PER, I-PER] 1 [Peter, Blackburn] [22, 22] [11, 12] [B-PER, I-PER]
2 [BRUSSELS, 1996-08-22] [B-LOC, O] 2 [BRUSSELS, 1996-08-22] [22, 11] [11, 12] [B-LOC, O]
3 [The, European, Commission, said, on, Thursday] [O, B-ORG, I-ORG, O, O, O] 3 [The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, to, consumers, to, shun, British, lamb, until, scientists, determine, whether, mad, cow, disease, can, be, transmitted, to, sheep, .] [12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21, 15, 24, 41, 15, 16, 21, 21, 20, 37, 40, 35, 21, 7] [11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12, 17, 11, 21, 17, 11, 12, 12, 21, 22, 22, 13, 11, 0] [O, B-ORG, I-ORG, O, O, O, O, O, O, B-MISC, O, O, O, O, O, B-MISC, O, O, O, O, O, O, O, O, O, O, O, O, O, O]

Labeling strategies


BaseLabelingStrategy

 BaseLabelingStrategy (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                       label_names:Optional[List[str]],
                       non_entity_label:str='O', ignore_token_id:int=-100)

Here we include a BaseLabelingStrategy abstract class and several different strategies for assigning labels to your tokenized inputs. The “only first token” and “B/I” labeling strategies are discussed in the “Token Classification” section in part 7 of the Hugging Face Transformers course.


BILabelingStrategy

 BILabelingStrategy (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                     label_names:Optional[List[str]],
                     non_entity_label:str='O', ignore_token_id:int=-100)

If using B/I labels, the first token associated with a given word gets the “B” label while all other tokens related to that same word get “I” labels. If “I” labels don’t exist, this strategy behaves like the OnlyFirstTokenLabelingStrategy. Works where labels are IDs or strings (in the latter case we’ll use the label_names to look up the label’s ID).


SameLabelLabelingStrategy

 SameLabelLabelingStrategy (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                            label_names:Optional[List[str]],
                            non_entity_label:str='O',
                            ignore_token_id:int=-100)

Every token associated with a given word is associated with the word’s label. Works where labels are IDs or strings (in the latter case we’ll use the label_names to look up the label’s ID).


OnlyFirstTokenLabelingStrategy

 OnlyFirstTokenLabelingStrategy (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                                 label_names:Optional[List[str]],
                                 non_entity_label:str='O',
                                 ignore_token_id:int=-100)

Only the first token of a word is associated with the word’s label (all other sub-tokens get the ignore_token_id). Works where labels are IDs or strings (in the latter case we’ll use the label_names to look up the label’s ID).
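
The three strategies can be compared with a small, self-contained sketch. It does not use BLURR's classes: the helper functions below are hypothetical, the `word_ids` list is hand-written to mimic what a fast tokenizer's word_ids returns (None for special tokens), and keeping the word's own label on continuation tokens of "O" and "I-" words in the B/I case is an assumption borrowed from the Hugging Face course.

```python
IGNORE = -100
label_names = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG"]


def only_first_token(word_ids, word_label_ids):
    """Label the first token of each word; ignore later sub-tokens and special tokens."""
    out, prev = [], None
    for wid in word_ids:
        out.append(IGNORE if wid is None or wid == prev else word_label_ids[wid])
        prev = wid
    return out


def same_label(word_ids, word_label_ids):
    """Give every token of a word the word's label; ignore special tokens."""
    return [IGNORE if wid is None else word_label_ids[wid] for wid in word_ids]


def bi_label(word_ids, word_label_ids, names):
    """First token keeps the word's label; later tokens of a B- word get the matching I- label."""
    out, prev = [], None
    for wid in word_ids:
        if wid is None:
            out.append(IGNORE)
        elif wid != prev:
            out.append(word_label_ids[wid])
        else:
            name = names[word_label_ids[wid]]
            if name.startswith("B-"):
                i_name = "I-" + name[2:]
                # No matching I- label: fall back to ignoring the sub-token.
                out.append(names.index(i_name) if i_name in names else IGNORE)
            else:
                out.append(word_label_ids[wid])  # assumption: O and I- labels carry over
        prev = wid
    return out


# Hypothetical tokenization of ["Wayde", "Gilliam", "loves"] with labels B-PER, I-PER, O.
word_ids = [None, 0, 0, 1, 1, 2, None]
word_label_ids = [1, 2, 0]

first = only_first_token(word_ids, word_label_ids)
same = same_label(word_ids, word_label_ids)
bi = bi_label(word_ids, word_label_ids, label_names)
print(first, same, bi, sep="\n")
```

The only positions where the three results differ are the second sub-tokens of each split word, which is exactly the choice a labeling strategy makes.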

Reconstructing inputs/labels

The utility methods below allow blurr users to reconstruct the original word/label associations from the input_ids/label associations. For example, these are used in our token classification show_batch method below.

# TESTS for get_token_labels_from_input_ids()
for idx in range(3):
    raw_word_list = conll2003_df.iloc[idx]["tokens"]
    raw_label_list = conll2003_df.iloc[idx]["ner_tags"]

    # Tokenize the pre-split words and give every token its word's label (special tokens are ignored)
    be = hf_tokenizer(raw_word_list, is_split_into_words=True)
    input_ids = be["input_ids"]
    targ_ids = [-100 if word_id is None else raw_label_list[word_id] for word_id in be.word_ids()]

    tok_labels = get_token_labels_from_input_ids(hf_tokenizer, input_ids, targ_ids, labels)

    for tok_label, targ_id in zip(tok_labels, [label_id for label_id in targ_ids if label_id != -100]):
        test_eq(tok_label[1], labels[targ_id])

get_token_labels_from_input_ids

 get_token_labels_from_input_ids (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                                  input_ids:List[int],
                                  token_label_ids:List[int],
                                  vocab:List[str],
                                  ignore_token_id:int=-100,
                                  ignore_token:str='[xIGNx]')

Given a list of input IDs, the label ID associated with each, and the labels vocab, this function returns a list of tuples where each tuple holds a “token” and its label name. For example: [(‘ĠWay’, B-PER), (‘de’, B-PER), (‘ĠGill’, I-PER), (‘iam’, I-PER), (‘Ġloves’, O), (‘ĠHug’, B-ORG), (‘ging’, B-ORG), (‘ĠFace’, I-ORG)]

| | Type | Default | Details |
|---|---|---|---|
| hf_tokenizer | PreTrainedTokenizerBase | | A Hugging Face tokenizer |
| input_ids | List | | List of input_ids for the tokens in a single piece of processed text |
| token_label_ids | List | | List of label indices for each token |
| vocab | List | | List of label names in which the label indices can be used to look up the name of each label |
| ignore_token_id | int | -100 | The token ID that should be ignored when calculating the loss |
| ignore_token | str | [xIGNx] | The token used to identify ignored tokens (default: [xIGNx]) |
| Returns | List | | |
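
A minimal sketch of this kind of lookup is shown below. It is not the library's code: it replaces the tokenizer with a hypothetical `id_to_token` dictionary and a set of special token IDs, and it assumes that special tokens are dropped while other ignored positions are reported with the ignore_token placeholder.

```python
def token_labels_from_ids(id_to_token, special_ids, input_ids, token_label_ids, vocab,
                          ignore_token_id=-100, ignore_token="[xIGNx]"):
    """Pair each non-special token with its label name (or the ignore placeholder)."""
    pairs = []
    for tok_id, label_id in zip(input_ids, token_label_ids):
        if tok_id in special_ids:
            continue  # assumption: special tokens such as <s> and </s> are left out
        label = ignore_token if label_id == ignore_token_id else vocab[label_id]
        pairs.append((id_to_token[tok_id], label))
    return pairs


# Hypothetical token IDs; a real tokenizer supplies its own vocabulary.
id_to_token = {0: "<s>", 2: "</s>", 10: "ĠHug", 11: "ging", 12: "ĠFace"}
vocab = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]

tok_label_pairs = token_labels_from_ids(
    id_to_token,
    special_ids={0, 2},
    input_ids=[0, 10, 11, 12, 2],
    token_label_ids=[-100, 3, -100, 4, -100],
    vocab=vocab,
)
print(tok_label_pairs)
```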
# TESTS for get_word_labels_from_token_labels()
for idx in range(5):
    raw_word_list = conll2003_df.iloc[idx]["tokens"]
    raw_label_list = conll2003_df.iloc[idx]["ner_tags"]

    # Tokenize the pre-split words and give every token its word's label (special tokens are ignored)
    be = hf_tokenizer(raw_word_list, is_split_into_words=True)
    input_ids = be["input_ids"]
    targ_ids = [-100 if word_id is None else raw_label_list[word_id] for word_id in be.word_ids()]

    tok_labels = get_token_labels_from_input_ids(hf_tokenizer, input_ids, targ_ids, labels)
    word_labels = get_word_labels_from_token_labels(hf_arch, hf_tokenizer, tok_labels)

    for word_label, raw_word, raw_label_id in zip(word_labels, raw_word_list, raw_label_list):
        test_eq(word_label[0], raw_word)
        test_eq(word_label[1], labels[raw_label_id])

get_word_labels_from_token_labels

 get_word_labels_from_token_labels (hf_arch:str,
                                    hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                                    tok_labels)

Given a list of tuples where each tuple defines a token and its label, return a list of tuples where each tuple defines a “word” and its label. This function assumes that model inputs are a list of words and, in conjunction with get_token_labels_from_input_ids, allows the user to reconstruct the original raw inputs and labels.

| | Type | Details |
|---|---|---|
| hf_arch | str | |
| hf_tokenizer | PreTrainedTokenizerBase | A Hugging Face tokenizer |
| tok_labels | | A list of tuples, where each represents a token and its label (e.g., [(‘ĠHug’, B-ORG), (‘ging’, B-ORG), (‘ĠFace’, I-ORG), …]) |
| Returns | List | |
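
For RoBERTa-style tokens, where a leading "Ġ" marks the start of a new word, the word-level reconstruction can be sketched as follows. The `word_labels_from_token_labels` helper is hypothetical and handles only that one convention; the real function also takes hf_arch, presumably because other tokenizers mark word boundaries differently (WordPiece, for example, prefixes continuation pieces with "##"). Using the label of a word's first token as the word's label is also an assumption.

```python
def word_labels_from_token_labels(tok_labels, word_start="Ġ"):
    """Merge sub-tokens back into words, keeping the label of each word's first token."""
    merged = []
    for tok, label in tok_labels:
        if tok.startswith(word_start) or not merged:
            # A new word begins: strip the boundary marker and remember its label.
            text = tok[len(word_start):] if tok.startswith(word_start) else tok
            merged.append([text, label])
        else:
            merged[-1][0] += tok  # continuation piece of the current word
    return [(word, label) for word, label in merged]


tok_labels_example = [
    ("ĠWay", "B-PER"), ("de", "B-PER"), ("ĠGill", "I-PER"), ("iam", "I-PER"),
    ("Ġloves", "O"), ("ĠHug", "B-ORG"), ("ging", "B-ORG"), ("ĠFace", "I-ORG"),
]
word_labels_example = word_labels_from_token_labels(tok_labels_example)
print(word_labels_example)
```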

Mid-level API


TokenTensorCategory

 TokenTensorCategory (x, **kwargs)

A Tensor which supports subclass pickling, and maintains metadata when casting or after methods


TokenCategorize

 TokenCategorize (vocab:List[str]=None, ignore_token:str='[xIGNx]',
                  ignore_token_id:int=-100)

Reversible transform of a list of category strings to vocab IDs

| | Type | Default | Details |
|---|---|---|---|
| vocab | List | None | The unique list of entities (e.g., B-LOC) (default: CategoryMap(vocab)) |
| ignore_token | str | [xIGNx] | The token used to identify ignored tokens (default: [xIGNx]) |
| ignore_token_id | int | -100 | The token ID that should be ignored when calculating the loss (default: CrossEntropyLossFlat().ignore_index) |

TokenCategorize modifies the fastai Categorize transform in a couple of ways.

First, it allows your targets to consist of a Category per token. Second, it uses an ignore_token_id to mask sub-tokens that don’t need a prediction. For example, the targets of special tokens (e.g., pad, cls, sep) are set to ignore_token_id, as are the second and later sub-tokens of any word that is split into more than one sub-token.
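
The encode/decode round trip this describes can be sketched without fastai. The `TokenCategorizeSketch` class below is hypothetical and only illustrates the mapping between label names and vocab IDs, with the ignore placeholder mapped to the ignore ID; it is not the transform itself.

```python
class TokenCategorizeSketch:
    """Map per-token label names to vocab IDs and back, with an ignore placeholder."""

    def __init__(self, vocab, ignore_token="[xIGNx]", ignore_token_id=-100):
        self.vocab = list(vocab)
        self.o2i = {name: idx for idx, name in enumerate(self.vocab)}
        self.ignore_token = ignore_token
        self.ignore_token_id = ignore_token_id

    def encode(self, names):
        # The placeholder becomes the ignore ID; every other name is looked up in the vocab.
        return [self.ignore_token_id if n == self.ignore_token else self.o2i[n] for n in names]

    def decode(self, ids):
        # The ignore ID becomes the placeholder again; every other ID indexes the vocab.
        return [self.ignore_token if i == self.ignore_token_id else self.vocab[i] for i in ids]


ner_vocab = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
tfm = TokenCategorizeSketch(ner_vocab)

encoded = tfm.encode(["B-ORG", "[xIGNx]", "O"])
decoded = tfm.decode(encoded)
print(encoded, decoded)
```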


TokenCategoryBlock

 TokenCategoryBlock (vocab:Optional[List[str]]=None,
                     ignore_token:str='[xIGNx]', ignore_token_id:int=-100)

TransformBlock for per-token categorical targets

| | Type | Default | Details |
|---|---|---|---|
| vocab | Optional | None | The unique list of entities (e.g., B-LOC) (default: CategoryMap(vocab)) |
| ignore_token | str | [xIGNx] | The token used to identify ignored tokens (default: [xIGNx]) |
| ignore_token_id | int | -100 | The token ID that should be ignored when calculating the loss (default: CrossEntropyLossFlat().ignore_index) |

TokenClassTextInput

 TokenClassTextInput (x, **kwargs)

The base representation of your inputs; used by the various fastai show methods

Again, we define a custom class, TokenClassTextInput, for the @typedispatched methods to use so that we can override how token classification inputs/targets are assembled, as well as how the data is shown via methods like show_batch and show_results.


TokenClassBatchTokenizeTransform

 TokenClassBatchTokenizeTransform (hf_arch:str,
                                   hf_config:transformers.configuration_utils.PretrainedConfig,
                                   hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                                   hf_model:transformers.modeling_utils.PreTrainedModel,
                                   include_labels:bool=True,
                                   ignore_token_id:int=-100,
                                   labeling_strategy_cls:__main__.BaseLabelingStrategy=<class '__main__.OnlyFirstTokenLabelingStrategy'>,
                                   target_label_names:Optional[List[str]]=None,
                                   non_entity_label:str='O',
                                   max_length:Optional[int]=None,
                                   padding:Union[bool,str]=True,
                                   truncation:Union[bool,str]=True,
                                   is_split_into_words:bool=True,
                                   slow_word_ids_func:Optional[Callable]=None,
                                   tok_kwargs:dict={}, **kwargs)

Handles everything you need to assemble a mini-batch of inputs and targets, as well as decode the dictionary produced as a byproduct of the tokenization process in the encodes method.

| | Type | Default | Details |
|---|---|---|---|
| hf_arch | str | | The abbreviation/name of your Hugging Face transformer architecture (e.g., bert, bart, etc.) |
| hf_config | PretrainedConfig | | A specific configuration instance you want to use |
| hf_tokenizer | PreTrainedTokenizerBase | | A Hugging Face tokenizer |
| hf_model | PreTrainedModel | | A Hugging Face model |
| include_labels | bool | True | To control whether the “labels” are included in your inputs. If they are, the loss will be calculated in the model’s forward function and you can simply use PreCalculatedLoss as your Learner’s loss function to use it |
| ignore_token_id | int | -100 | The token ID that should be ignored when calculating the loss |
| labeling_strategy_cls | BaseLabelingStrategy | OnlyFirstTokenLabelingStrategy | The labeling strategy you want to apply when associating labels with word tokens |
| target_label_names | Optional | None | The target label names |
| non_entity_label | str | O | The label for non-entity |
| max_length | Optional | None | To control the length of the padding/truncation. It can be an integer or None, in which case it will default to the maximum length the model can accept. If the model has no specific maximum input length, truncation/padding to max_length is deactivated. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| padding | Union | True | To control the padding applied to your hf_tokenizer during tokenization. If None, will default to False or 'do_not_pad'. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| truncation | Union | True | To control the truncation applied to your hf_tokenizer during tokenization. If None, will default to False or 'do_not_truncate'. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| is_split_into_words | bool | True | The is_split_into_words argument applied to your hf_tokenizer during tokenization. Set this to True if your inputs are pre-tokenized (not numericalized) |
| slow_word_ids_func | Optional | None | If using a slow tokenizer, users will need to provide a slow_word_ids_func that accepts a tokenizer, example index, and a batch encoding as arguments and in turn returns the equivalent of a fast tokenizer’s word_ids |
| tok_kwargs | dict | {} | Any other keyword arguments you want included when using your hf_tokenizer to tokenize your inputs |
| kwargs | | | |

TokenClassBatchTokenizeTransform is used to exclude any of the target’s tokens we don’t want to include in the loss calculation (e.g., padding, cls, sep).

Note also that we default is_split_into_words = True since token classification tasks expect a list of words and labels for each word.
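
To illustrate why excluded positions are marked with -100, the sketch below computes a mean cross-entropy that skips any target equal to the ignore ID. It is a plain-Python stand-in written for this page, not code from BLURR or fastai; the `masked_cross_entropy` helper and the toy logits are hypothetical. PyTorch's CrossEntropyLoss uses -100 as its default ignore_index, which is where the value comes from.

```python
import math


def masked_cross_entropy(logits, targets, ignore_token_id=-100):
    """Mean negative log-likelihood over positions whose target is not ignored."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_token_id:
            continue  # padding, special tokens, and skipped sub-tokens add nothing
        peak = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else 0.0


# Two token positions, two classes; the second position is ignored.
toy_loss = masked_cross_entropy([[0.0, 0.0], [9.0, -9.0]], [1, -100])
print(toy_loss)
```

Only the first position contributes here, so the result depends on its logits alone no matter how confident the model is about the ignored one.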

Examples

Using the mid-level API

Batch-Time Tokenization

Step 1: Get your Hugging Face objects.
hf_logging.set_verbosity_error()

pretrained_model_name = "distilroberta-base"
n_labels = len(labels)

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
    pretrained_model_name, model_cls=AutoModelForTokenClassification, config_kwargs={"num_labels": n_labels}
)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('roberta',
 transformers.models.roberta.configuration_roberta.RobertaConfig,
 transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast,
 transformers.models.roberta.modeling_roberta.RobertaForTokenClassification)
Step 2: Create your DataBlock
batch_tok_tfm = TokenClassBatchTokenizeTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model, labeling_strategy_cls=BILabelingStrategy, target_label_names=labels
)
blocks = (TextBlock(batch_tokenize_tfm=batch_tok_tfm, input_return_type=TokenClassTextInput), TokenCategoryBlock(vocab=labels))

dblock = DataBlock(blocks=blocks, get_x=ColReader("tokens"), get_y=ColReader("ner_tags"), splitter=RandomSplitter())
Step 3: Build your DataLoaders
dls = dblock.dataloaders(conll2003_df, bs=4)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[1].shape
(2, torch.Size([4, 156]), torch.Size([4, 156]))
dls.show_batch(dataloaders=dls, max_n=5, trunc_at=20)
word / target label
0 [('MARKET', 'O'), ('TALK', 'O'), ('-', 'O'), ('USDA', 'B-ORG'), ('net', 'O'), ('change', 'O'), ('in', 'O'), ('weekly', 'O'), ('export', 'O'), ('commitments', 'O'), ('for', 'O'), ('the', 'O'), ('week', 'O'), ('ended', 'O'), ('August', 'O'), ('22', 'O'), (',', 'O'), ('includes', 'O'), ('old', 'O'), ('crop', 'O')]
1 [('Slough', 'B-ORG'), ("'s", 'O'), ('chairman', 'O'), ('Sir', 'O'), ('Nigel', 'B-PER'), ('Mobbs', 'I-PER'), ('added', 'O'), ('to', 'O'), ('the', 'O'), ('bullish', 'O'), ('mood', 'O'), ('in', 'O'), ('the', 'O'), ('sector', 'O'), (',', 'O'), ('saying', 'O'), ('in', 'O'), ('a', 'O'), ('statement', 'O'), ('that', 'O')]
2 [('The', 'O'), ('government-owned', 'O'), ('al-Ingaz', 'B-ORG'), ('al-Watani', 'I-ORG'), ('said', 'O'), ('the', 'O'), ('smugglers', 'O'), ('were', 'O'), ('caught', 'O'), ('in', 'O'), ('Banat', 'B-LOC'), ('in', 'O'), ('the', 'O'), ('eastern', 'O'), ('state', 'O'), ('of', 'O'), ('Kassala', 'B-LOC'), (',', 'O'), ('on', 'O'), ('the', 'O')]
3 [('"', 'O'), ('The', 'O'), ('ultimatum', 'O'), ('(', 'O'), ('to', 'O'), ('storm', 'O'), ('Grozny', 'B-LOC'), (')', 'O'), ('is', 'O'), ('no', 'O'), ('longer', 'O'), ('an', 'O'), ('issue', 'O'), (',', 'O'), ('"', 'O'), ('he', 'O'), ('said', 'O'), ('quoting', 'O'), ('Ischinger', 'B-PER'), (',', 'O')]

Passing extra information

Step 1a: Get your Hugging Face objects.
hf_logging.set_verbosity_error()

pretrained_model_name = "distilroberta-base"
n_labels = len(labels)

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
    pretrained_model_name, model_cls=AutoModelForTokenClassification, config_kwargs={"num_labels": n_labels}
)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('roberta',
 transformers.models.roberta.configuration_roberta.RobertaConfig,
 transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast,
 transformers.models.roberta.modeling_roberta.RobertaForTokenClassification)
Step 1b: Preprocess your dataset
preprocessor = TokenClassPreprocessor(
    hf_tokenizer,
    label_names=labels,
    id_attr="id",
    word_list_attr="tokens",
    label_list_attr="ner_tags",
    tok_kwargs={"max_length": 128},
)
proc_df = preprocessor.process_df(conll2003_df)
proc_df.head(2)
proc_tokens proc_ner_tags id tokens pos_tags chunk_tags ner_tags
0 [EU, rejects, German, call, to, boycott, British, lamb, .] [3, 0, 7, 0, 0, 0, 7, 0, 0] 0 [EU, rejects, German, call, to, boycott, British, lamb, .] [22, 42, 16, 21, 35, 37, 16, 21, 7] [11, 21, 11, 12, 21, 22, 11, 12, 0] [3, 0, 7, 0, 0, 0, 7, 0, 0]
1 [Peter, Blackburn] [1, 2] 1 [Peter, Blackburn] [22, 22] [11, 12] [1, 2]
Step 2: Create your DataBlock
batch_tok_tfm = TokenClassBatchTokenizeTransform(hf_arch, hf_config, hf_tokenizer, hf_model, target_label_names=labels)
blocks = (TextBlock(batch_tokenize_tfm=batch_tok_tfm, input_return_type=TokenClassTextInput), TokenCategoryBlock(vocab=labels))


def get_x(item):
    return {"id": item.id, "text": item.proc_tokens}


dblock = DataBlock(blocks=blocks, get_x=get_x, get_y=ColReader("proc_ner_tags"), splitter=RandomSplitter())
Step 3: Build your DataLoaders
dls = dblock.dataloaders(proc_df, bs=4)
b = dls.one_batch()
b[0].keys()
dict_keys(['input_ids', 'attention_mask', 'id', 'labels'])
len(b), b[0]["input_ids"].shape, b[1].shape
(2, torch.Size([4, 130]), torch.Size([4, 130]))
dls.show_batch(dataloaders=dls, max_n=5, trunc_at=20)
word / target label
0 [('MARKET', 'O'), ('TALK', 'O'), ('-', 'O'), ('USDA', 'B-ORG'), ('net', 'O'), ('change', 'O'), ('in', 'O'), ('weekly', 'O'), ('export', 'O'), ('commitments', 'O'), ('for', 'O'), ('the', 'O'), ('week', 'O'), ('ended', 'O'), ('August', 'O'), ('22', 'O'), (',', 'O'), ('includes', 'O'), ('old', 'O'), ('crop', 'O')]
1 [('"', 'O'), ('This', 'O'), ('finding', 'O'), ('is', 'O'), ('important', 'O'), ('because', 'O'), ('one', 'O'), ('of', 'O'), ('the', 'O'), ('jars', 'O'), ('still', 'O'), ('contains', 'O'), ('substances', 'O'), ('and', 'O'), ('materials', 'O'), ('used', 'O'), ('in', 'O'), ('the', 'O'), ('conservation', 'O'), ('of', 'O')]
2 [('"', 'O'), ('We', 'O'), ('have', 'O'), ('always', 'O'), ('been', 'O'), ('concerned', 'O'), ('about', 'O'), ('barter', 'O'), ('deals', 'O'), ('with', 'O'), ('other', 'O'), ('countries', 'O'), (',', 'O'), ('viewing', 'O'), ('them', 'O'), ('as', 'O'), ('a', 'O'), ('disguised', 'O'), ('kind', 'O'), ('of', 'O')]
3 [('The', 'O'), ('officials', 'O'), ('had', 'O'), ('been', 'O'), ('positive', 'O'), ('about', 'O'), ('Kinkel', 'B-PER'), ("'s", 'O'), ('request', 'O'), ('on', 'O'), ('Wednesday', 'O'), ('that', 'O'), ('President', 'O'), ('Boris', 'B-PER'), ('Yeltsin', 'I-PER'), ("'s", 'O'), ('security', 'O'), ('chief', 'O'), ('Alexander', 'B-PER'), ('Lebed', 'I-PER')]

Tests

The tests below ensure the core DataBlock code above works for all pretrained token classification models available in Hugging Face. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained token classification models you are working with. If any of your pretrained token classification models fail, please submit a GitHub issue (or a PR if you’d like to fix it yourself).

raw_datasets = load_dataset("conll2003")
conll2003_df = pd.DataFrame(raw_datasets["train"])

labels = raw_datasets["train"].features["ner_tags"].feature.names
Reusing dataset conll2003 (/home/wgilliam/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)
| | arch | tokenizer | model_name | result | error |
|---|---|---|---|---|---|
| 0 | albert | AlbertTokenizerFast | hf-internal-testing/tiny-albert | PASSED | |
| 1 | bert | BertTokenizerFast | hf-internal-testing/tiny-bert | PASSED | |
| 2 | big_bird | BigBirdTokenizerFast | google/bigbird-roberta-base | PASSED | |
| 3 | camembert | CamembertTokenizerFast | camembert-base | PASSED | |
| 4 | convbert | ConvBertTokenizerFast | YituTech/conv-bert-base | PASSED | |
| 5 | deberta | DebertaTokenizerFast | hf-internal-testing/tiny-deberta | PASSED | |
| 6 | bert | BertTokenizerFast | sshleifer/tiny-distilbert-base-cased | PASSED | |
| 7 | electra | ElectraTokenizerFast | hf-internal-testing/tiny-electra | PASSED | |
| 8 | funnel | FunnelTokenizerFast | huggingface/funnel-small-base | PASSED | |
| 9 | gpt2 | GPT2TokenizerFast | sshleifer/tiny-gpt2 | PASSED | |
| 10 | layoutlm | LayoutLMTokenizerFast | hf-internal-testing/tiny-layoutlm | PASSED | |
| 11 | longformer | LongformerTokenizerFast | allenai/longformer-base-4096 | PASSED | |
| 12 | mpnet | MPNetTokenizerFast | microsoft/mpnet-base | PASSED | |
| 13 | ibert | RobertaTokenizerFast | kssteven/ibert-roberta-base | PASSED | |
| 14 | mobilebert | MobileBertTokenizerFast | google/mobilebert-uncased | PASSED | |
| 15 | rembert | RemBertTokenizerFast | google/rembert | PASSED | |
| 16 | roformer | RoFormerTokenizerFast | junnyu/roformer_chinese_sim_char_ft_small | PASSED | |
| 17 | roberta | RobertaTokenizerFast | roberta-base | PASSED | |
| 18 | squeezebert | SqueezeBertTokenizerFast | squeezebert/squeezebert-uncased | PASSED | |
| 19 | xlm_roberta | XLMRobertaTokenizerFast | xlm-roberta-base | PASSED | |
| 20 | xlnet | XLNetTokenizerFast | xlnet-base-cased | PASSED | |