This module contains the core seq2seq (e.g., language modeling, summarization, translation) bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data in a way that can be modeled by huggingface transformer implementations.
import torch

torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)

Base tokenization, batch transform, and DataBlock methods

Seq2Seq tasks are essentially conditional generation tasks, and this includes derived tasks such as summarization and translation. Given this, we can use the same HF_Seq2Seq transforms, HF_Seq2SeqInput, and HF_Seq2SeqBlock for all of these tasks.

class HF_Seq2SeqInput[source]

HF_Seq2SeqInput(x, **kwargs) :: HF_BaseInput

A Tensor which supports subclass pickling and maintains metadata when casting or after methods

We create a subclass of HF_BeforeBatchTransform for summarization tasks to add decoder_input_ids and labels to our inputs during training, which in turn allows the huggingface model to calculate the loss for us. See here and here for more information on these additional inputs used in summarization, translation, and conversational training tasks. How they should look for a particular architecture can be found in that model's forward function docs (see here for BART, for example).

Note also that decoder_input_ids are simply the target's input_ids (the labels) shifted to the right by one, since the task is to predict the next token based on the current (and all previous) decoder_input_ids.

And lastly, we also update our targets to just be the input_ids of our target sequence so that fastai's Learner.show_results works (again, almost all the fastai bits require returning a single tensor to work).
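
As a rough illustration of the shift described above, here is a minimal sketch (not part of the library) that builds decoder_input_ids from labels by hand. It assumes the hf_tokenizer and hf_config objects created earlier; the example target text is arbitrary.

# Minimal sketch (illustrative only): derive `decoder_input_ids` from `labels` by hand
target_text = "The quick brown fox jumps over the lazy dog."
labels = hf_tokenizer(target_text, return_tensors='pt')['input_ids']

# shift right: prepend the decoder start token and drop the final token
start_ids = torch.full((labels.shape[0], 1), hf_config.decoder_start_token_id, dtype=torch.long)
decoder_input_ids = torch.cat([start_ids, labels[:, :-1]], dim=-1)

labels, decoder_input_ids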

default_text_gen_kwargs[source]

default_text_gen_kwargs(hf_config, hf_model, task=None)

default_text_gen_kwargs(hf_config, hf_model)
{'max_length': 142,
 'min_length': 56,
 'do_sample': False,
 'early_stopping': True,
 'num_beams': 4,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'repetition_penalty': 1.0,
 'bad_words_ids': None,
 'bos_token_id': 0,
 'pad_token_id': 1,
 'eos_token_id': 2,
 'length_penalty': 2.0,
 'no_repeat_ngram_size': 3,
 'encoder_no_repeat_ngram_size': 0,
 'num_return_sequences': 1,
 'decoder_start_token_id': 2,
 'use_cache': True,
 'num_beam_groups': 1,
 'diversity_penalty': 0.0,
 'output_attentions': False,
 'output_hidden_states': False,
 'output_scores': False,
 'return_dict_in_generate': False,
 'forced_bos_token_id': 0,
 'forced_eos_token_id': 2,
 'remove_invalid_values': False}
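
These defaults can be overridden by passing your own text_gen_kwargs to the before-batch transform or to HF_Seq2SeqBlock below. A minimal sketch; the specific override values here are purely illustrative.

# Start from the defaults and override a few generation arguments (values are illustrative)
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model)
text_gen_kwargs.update({'max_length': 130, 'min_length': 30, 'num_beams': 2})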

class HF_Seq2SeqBeforeBatchTransform[source]

HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model, ignore_token_id=-100, max_length=None, max_target_length=None, padding=True, truncation=True, tok_kwargs={}, text_gen_kwargs={}, **kwargs) :: HF_BeforeBatchTransform

Handles everything you need to assemble a mini-batch of inputs and targets, as well as decode the dictionary produced as a byproduct of the tokenization process in the encodes method.
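
For reference, a minimal sketch of constructing this transform with the objects created above; the max_length and max_target_length values are assumptions for a summarization task, not library defaults.

# Sketch: build the before-batch transform for a summarization task (length values are assumptions)
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model,
    max_length=256,          # pad/truncate the *input* sequence to this length
    max_target_length=130,   # pad/truncate the *target* sequence to this length
    text_gen_kwargs=default_text_gen_kwargs(hf_config, hf_model)
)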

We include a new AFTER batch Transform and TransformBlock specific to text-to-text tasks.

class HF_Seq2SeqAfterBatchTransform[source]

HF_Seq2SeqAfterBatchTransform(hf_tokenizer, input_return_type=HF_BaseInput) :: HF_AfterBatchTransform

Delegates (__call__,decode,setup) to (encodes,decodes,setups) if split_idx matches

class HF_Seq2SeqBlock[source]

HF_Seq2SeqBlock(hf_arch=None, hf_config=None, hf_tokenizer=None, hf_model=None, before_batch_tfm=None, after_batch_tfm=None, max_length=None, max_target_length=None, padding=True, truncation=True, input_return_type=HF_Seq2SeqInput, dl_type=SortedDL, tok_kwargs={}, text_gen_kwargs={}, before_batch_kwargs={}, after_batch_kwargs={}, **kwargs) :: HF_TextBlock

A basic wrapper that links default transforms for the data block API

... and a DataLoaders.show_batch for seq2seq tasks
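
Putting it all together, here is a hedged sketch of a summarization DataBlock built with HF_Seq2SeqBlock. The df DataFrame and its article/highlights column names are assumptions (e.g., a CNN/DailyMail-style dataset), as is the batch size.

from fastai.text.all import *

# `df` is assumed to be a DataFrame with an 'article' (input) and 'highlights' (target) column
blocks = (HF_Seq2SeqBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)

dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('article'),
                   get_y=ColReader('highlights'),
                   splitter=RandomSplitter())

dls = dblock.dataloaders(df, bs=4)
dls.show_batch(dataloaders=dls, max_n=2)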

Cleanup