Data

The text.data.seq2seq.core module contains the core seq2seq (e.g., language modeling, summarization, translation) bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data in a way that can be modeled by Hugging Face transformer implementations.

Setup

# Import paths assume the transformers library and the BLURR 2.x module layout
from transformers import BartForConditionalGeneration
from blurr.text.utils import get_hf_objects

pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)

Preprocessing

Starting with version 2.0, BLURR provides a preprocessing base class that can be used to build preprocessed seq2seq datasets from pandas DataFrames or Hugging Face Datasets.



Seq2SeqPreprocessor

 Seq2SeqPreprocessor (hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                      batch_size:int=1000, text_attr:str='text',
                      max_input_tok_length:Optional[int]=None,
                      target_text_attr:str='summary',
                      max_target_tok_length:Optional[int]=None,
                      is_valid_attr:Optional[str]='is_valid',
                      tok_kwargs:dict={})

A preprocessing class for building preprocessed seq2seq datasets from pandas DataFrames or Hugging Face Datasets.

| | Type | Default | Details |
|---|---|---|---|
| hf_tokenizer | PreTrainedTokenizerBase | | A Hugging Face tokenizer |
| batch_size | int | 1000 | The number of examples to process at a time |
| text_attr | str | text | The attribute holding the text |
| max_input_tok_length | Optional | None | The maximum length (# of tokens) allowed for inputs. Will default to the max length allowed by the model if not provided |
| target_text_attr | str | summary | The attribute holding the summary |
| max_target_tok_length | Optional | None | The maximum length (# of tokens) allowed for targets |
| is_valid_attr | Optional | is_valid | The attribute that should be created if you are processing individual training and validation datasets into a single dataset, and which will indicate to which dataset each example belongs |
| tok_kwargs | dict | {} | Tokenization kwargs that will be applied when calling the tokenizer |
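
For example, a preprocessor for a summarization dataset could be constructed as in the minimal sketch below. The "article"/"highlights" column names and the token-length caps are illustrative, not defaults of the API; hf_tokenizer comes from the Setup section.

# A minimal sketch using the documented constructor; the column names and
# length caps below are illustrative values.
preprocessor = Seq2SeqPreprocessor(
    hf_tokenizer,
    batch_size=1000,                # process 1,000 examples at a time
    text_attr="article",            # column holding the input documents
    max_input_tok_length=256,       # cap inputs at 256 tokens
    target_text_attr="highlights",  # column holding the reference summaries
    max_target_tok_length=130,      # cap targets at 130 tokens
)

The resulting preprocessor can then be run over your pandas DataFrame or Hugging Face Dataset before building your DataLoaders.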

Mid-level API

Base tokenization, batch transform, and DataBlock methods



Seq2SeqTextInput

 Seq2SeqTextInput (x, **kwargs)

The base representation of your inputs; used by the various fastai show methods

A Seq2SeqTextInput object is returned from the decodes method of Seq2SeqBatchTokenizeTransform as a means to customize @typedispatched functions like DataLoaders.show_batch and Learner.show_results. Its value will be your "input_ids".



Seq2SeqBatchTokenizeTransform

 Seq2SeqBatchTokenizeTransform (hf_arch:str,
                                hf_config:transformers.configuration_utils.PretrainedConfig,
                                hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase,
                                hf_model:transformers.modeling_utils.PreTrainedModel,
                                include_labels:bool=True,
                                ignore_token_id:int=-100,
                                max_length:int=None,
                                max_target_length:int=None,
                                padding:Union[bool,str]=True,
                                truncation:Union[bool,str]=True,
                                is_split_into_words:bool=False,
                                tok_kwargs={}, text_gen_kwargs={},
                                **kwargs)

Handles everything you need to assemble a mini-batch of inputs and targets, as well as decode the dictionary produced as a byproduct of the tokenization process in the encodes method.

| | Type | Default | Details |
|---|---|---|---|
| hf_arch | str | | The abbreviation/name of your Hugging Face transformer architecture (e.g., bert, bart, etc.) |
| hf_config | PretrainedConfig | | A specific configuration instance you want to use |
| hf_tokenizer | PreTrainedTokenizerBase | | A Hugging Face tokenizer |
| hf_model | PreTrainedModel | | A Hugging Face model |
| include_labels | bool | True | Controls whether the "labels" are included in your inputs. If they are, the loss will be calculated in the model's forward function and you can simply use PreCalculatedLoss as your Learner's loss function to use it |
| ignore_token_id | int | -100 | The token ID that should be ignored when calculating the loss |
| max_length | int | None | Controls the length of the padding/truncation of the input sequence. It can be an integer or None, in which case it will default to the maximum length the model can accept. If the model has no specific maximum input length, truncation/padding to max_length is deactivated. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| max_target_length | int | None | Controls the length of the padding/truncation of the target sequence. It can be an integer or None, in which case it will default to the maximum length the model can accept. If the model has no specific maximum input length, truncation/padding to max_length is deactivated. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| padding | Union[bool, str] | True | Controls the padding applied to your hf_tokenizer during tokenization. If None, will default to False or 'do_not_pad'. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| truncation | Union[bool, str] | True | Controls the truncation applied to your hf_tokenizer during tokenization. If None, will default to False or 'do_not_truncate'. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| is_split_into_words | bool | False | The is_split_into_words argument applied to your hf_tokenizer during tokenization. Set this to True if your inputs are pre-tokenized (not numericalized) |
| tok_kwargs | dict | {} | Any other keyword arguments you want included when using your hf_tokenizer to tokenize your inputs |
| text_gen_kwargs | dict | {} | Any keyword arguments to pass to the hf_model.generate method |
| kwargs | | | |
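
As a sketch of how the pieces fit together, the transform can be built directly from the objects created in the Setup section; the length caps and generation overrides below are illustrative values, not defaults.

# A minimal sketch built from the Setup objects; the max lengths and
# text_gen_kwargs values are illustrative, not required.
batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model,
    max_length=256,          # pad/truncate inputs to 256 tokens
    max_target_length=130,   # pad/truncate targets to 130 tokens
    text_gen_kwargs={"num_beams": 4, "max_length": 130},  # forwarded to hf_model.generate
)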

We create a subclass of BatchTokenizeTransform for summarization tasks that adds decoder_input_ids and labels (if we want Hugging Face to calculate the loss for us) to our inputs during training. See here and here for more information on these additional inputs used in summarization, translation, and conversational training tasks. How they should look for particular architectures can be found in those models' forward function docs (see here for BART, for example).

Note also that labels is simply target_ids shifted to the right by one, since the task is to predict the next token based on the current (and all previous) decoder_input_ids.

And lastly, we also update our targets to just be the input_ids of our target sequence so that fastai’s Learner.show_results works (again, almost all the fastai bits require returning a single tensor to work).
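
For reference, Hugging Face's BART implementation uses exactly this shift-by-one relationship when it derives decoder_input_ids from labels. The snippet below is a small sketch (not BLURR's internal code) using the shift_tokens_right helper from transformers; hf_tokenizer and hf_config come from the Setup section.

# Sketch only: illustrates the shift-by-one relationship for a BART-style model.
from transformers.models.bart.modeling_bart import shift_tokens_right

target_ids = hf_tokenizer("A short reference summary.", return_tensors="pt").input_ids

# positions that should not contribute to the loss get the ignore index (-100)
labels = target_ids.clone()
labels[labels == hf_tokenizer.pad_token_id] = -100

# decoder_input_ids start with the decoder_start_token_id and contain the
# target ids shifted one position to the right
decoder_input_ids = shift_tokens_right(
    target_ids, hf_config.pad_token_id, hf_config.decoder_start_token_id
)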



Seq2SeqBatchDecodeTransform

 Seq2SeqBatchDecodeTransform (input_return_type:type=<class 'blurr.text.data.core.TextInput'>,
                              hf_arch:str=None,
                              hf_config:PretrainedConfig=None,
                              hf_tokenizer:PreTrainedTokenizerBase=None,
                              hf_model:PreTrainedModel=None, **kwargs)

A class used to cast your inputs as input_return_type for fastai show methods

| | Type | Default | Details |
|---|---|---|---|
| input_return_type | type | TextInput | Used by typedispatched show methods |
| hf_arch | str | None | The abbreviation/name of your Hugging Face transformer architecture (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| hf_config | PretrainedConfig | None | A Hugging Face configuration object (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| hf_tokenizer | PreTrainedTokenizerBase | None | A Hugging Face tokenizer (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| hf_model | PreTrainedModel | None | A Hugging Face model (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| kwargs | | | |


default_text_gen_kwargs

 default_text_gen_kwargs (hf_config, hf_model, task=None)

default_text_gen_kwargs(hf_config, hf_model)
{'max_length': 142,
 'min_length': 56,
 'do_sample': False,
 'early_stopping': True,
 'num_beams': 4,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'repetition_penalty': 1.0,
 'bad_words_ids': None,
 'bos_token_id': 0,
 'pad_token_id': 1,
 'eos_token_id': 2,
 'length_penalty': 2.0,
 'no_repeat_ngram_size': 3,
 'encoder_no_repeat_ngram_size': 0,
 'num_return_sequences': 1,
 'decoder_start_token_id': 2,
 'use_cache': True,
 'num_beam_groups': 1,
 'diversity_penalty': 0.0,
 'output_attentions': False,
 'output_hidden_states': False,
 'output_scores': False,
 'return_dict_in_generate': False,
 'forced_bos_token_id': 0,
 'forced_eos_token_id': 2,
 'remove_invalid_values': False}
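
The returned dictionary can be used as a starting point and selectively overridden before being passed to anything that accepts text_gen_kwargs (e.g., Seq2SeqBatchTokenizeTransform or Seq2SeqTextBlock). A small sketch; the overridden values are illustrative:

# start from the model's generation defaults and override a few values
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model)
text_gen_kwargs.update({"num_beams": 2, "max_length": 100})  # illustrative overrides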


Seq2SeqTextBlock

 Seq2SeqTextBlock (hf_arch:str=None,
                   hf_config:transformers.configuration_utils.PretrainedConfig=None,
                   hf_tokenizer:transformers.tokenization_utils_base.PreTrainedTokenizerBase=None,
                   hf_model:transformers.modeling_utils.PreTrainedModel=None,
                   batch_tokenize_tfm:Optional[blurr.text.data.core.BatchTokenizeTransform]=None,
                   batch_decode_tfm:Optional[blurr.text.data.core.BatchDecodeTransform]=None,
                   max_length:int=None, max_target_length=None,
                   padding:Union[bool,str]=True,
                   truncation:Union[bool,str]=True,
                   input_return_type=<class '__main__.Seq2SeqTextInput'>,
                   dl_type=<class 'fastai.text.data.SortedDL'>,
                   batch_tokenize_kwargs:dict={},
                   batch_decode_kwargs:dict={}, tok_kwargs={},
                   text_gen_kwargs={}, **kwargs)

The core TransformBlock to prepare your inputs for training in Blurr with fastai’s DataBlock API

| | Type | Default | Details |
|---|---|---|---|
| hf_arch | str | None | The abbreviation/name of your Hugging Face transformer architecture (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| hf_config | PretrainedConfig | None | A Hugging Face configuration object (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| hf_tokenizer | PreTrainedTokenizerBase | None | A Hugging Face tokenizer (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| hf_model | PreTrainedModel | None | A Hugging Face model (not required if passing in an instance of BatchTokenizeTransform to before_batch_tfm) |
| batch_tokenize_tfm | Optional | None | The before_batch_tfm you want to use to tokenize your raw data on the fly (defaults to an instance of BatchTokenizeTransform) |
| batch_decode_tfm | Optional | None | The batch_tfm you want to use to decode your inputs into a type that can be used in the fastai show methods (defaults to BatchDecodeTransform) |
| max_length | int | None | Controls the length of the padding/truncation for the input sequence. It can be an integer or None, in which case it will default to the maximum length the model can accept. If the model has no specific maximum input length, truncation/padding to max_length is deactivated. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| max_target_length | NoneType | None | Controls the length of the padding/truncation for the target sequence. It can be an integer or None, in which case it will default to the maximum length the model can accept. If the model has no specific maximum input length, truncation/padding to max_length is deactivated. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| padding | Union[bool, str] | True | Controls the padding applied to your hf_tokenizer during tokenization. If None, will default to False or 'do_not_pad'. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| truncation | Union[bool, str] | True | Controls the truncation applied to your hf_tokenizer during tokenization. If None, will default to False or 'do_not_truncate'. See [Everything you always wanted to know about padding and truncation](https://huggingface.co/transformers/preprocessing.html#everything-you-always-wanted-to-know-about-padding-and-truncation) |
| input_return_type | _TensorMeta | Seq2SeqTextInput | The return type your decoded inputs should be cast to (used by methods such as show_batch) |
| dl_type | type | SortedDL | The type of DataLoader you want created (defaults to SortedDL) |
| batch_tokenize_kwargs | dict | {} | Any keyword arguments you want applied to your batch_tokenize_tfm |
| batch_decode_kwargs | dict | {} | Any keyword arguments you want applied to your batch_decode_tfm (will be set as fastai batch_tfms) |
| tok_kwargs | dict | {} | Any keyword arguments you want your Hugging Face tokenizer to use during tokenization |
| text_gen_kwargs | dict | {} | Any keyword arguments you want applied when generating text (default: default_text_gen_kwargs) |
| kwargs | | | |
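
Putting it together, a typical DataBlock setup might look like the sketch below. It assumes a DataFrame df with hypothetical "article" (input) and "highlights" (target summary) columns; noop is used for the target block because the targets are assembled inside the batch tokenize transform.

from fastai.text.all import *

# A minimal sketch, assuming a DataFrame `df` with hypothetical
# "article" (input) and "highlights" (target summary) columns.
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)

dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader("article"),
    get_y=ColReader("highlights"),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
)

dls = dblock.dataloaders(df, bs=4)
dls.show_batch(max_n=2)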

show_batch