Data

The text.data.seq2seq.translation module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for translation tasks.

Setup

We’ll use a subset of the wmt16 dataset to demonstrate how to configure BLURR for translation tasks.
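The cells below assume the usual imports are already in scope. A minimal sketch of what that might look like is shown here; the exact blurr module paths may differ slightly depending on the version you have installed:

# core libraries used throughout this page
import pandas as pd
from datasets import load_dataset
from fastai.text.all import *
from transformers import AutoModelForSeq2SeqLM

# blurr data utilities (paths assume a blurr 2.x-style layout)
from blurr.text.data.all import *
from blurr.text.utils import get_hf_objects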

raw_dataset = load_dataset("wmt16", "de-en", split="train[:1%]")
raw_dataset
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f)
Dataset({
    features: ['translation'],
    num_rows: 45489
})
print(raw_dataset[0].keys())
print(raw_dataset[0])
dict_keys(['translation'])
{'translation': {'de': 'Wiederaufnahme der Sitzungsperiode', 'en': 'Resumption of the session'}}
wmt_df = pd.DataFrame(raw_dataset["translation"], columns=["de", "en"])

print(len(wmt_df))
wmt_df.head(2)
45489
de en
0 Wiederaufnahme der Sitzungsperiode Resumption of the session
1 Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten. I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('bart',
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)

Preprocessing

Starting with version 2.0, BLURR provides a preprocessing base class that can be used to build task-specific pre-processed datasets from pandas DataFrames or Hugging Face Datasets.


source

TranslationPreprocessor

 TranslationPreprocessor (hf_tokenizer:transformers.tokenization_utils_bas
                          e.PreTrainedTokenizerBase, batch_size:int=1000,
                          id_attr:Optional[str]=None,
                          text_attr:str='original_text',
                          max_input_tok_length:Optional[int]=None,
                          target_text_attr:str='translated_text',
                          max_target_tok_length:Optional[int]=None,
                          is_valid_attr:Optional[str]='is_valid',
                          tok_kwargs:dict={})

A preprocessor that tokenizes (and optionally truncates) your source and target texts ahead of time for translation tasks.

Type Default Details
hf_tokenizer PreTrainedTokenizerBase A Hugging Face tokenizer
batch_size int 1000 The number of examples to process at a time
id_attr Optional None The unique identifier in the dataset
text_attr str original_text The attribute holding the text to translate
max_input_tok_length Optional None The maximum length (# of tokens) allowed for inputs. Will default to the max length allowed by the model if not provided
target_text_attr str translated_text The attribute holding the target (translated) text
max_target_tok_length Optional None The maximum length (# of tokens) allowed for targets
is_valid_attr Optional is_valid The attribute that will be created if you are processing separate training and validation datasets into a single dataset; it indicates which dataset each example belongs to
tok_kwargs dict {} Tokenization kwargs that will be applied when calling the tokenizer

This class can be used for preprocessing translation tasks. It adds proc_{your_text_attr} and proc_{target_text_attr} attributes containing your modified input and target texts after tokenization (e.g., if you specify a max_length, proc_{your_text_attr} may contain truncated text).

Using a DataFrame

preprocessor = TranslationPreprocessor(
    hf_tokenizer, text_attr="de", target_text_attr="en", max_input_tok_length=128, max_target_tok_length=128
)
proc_df = preprocessor.process_df(wmt_df)
proc_df.columns, len(proc_df)
proc_df.head(2)
proc_en proc_de de en de_start_char_idx de_end_char_idx en_start_char_idx en_end_char_idx
0 Resumption of the session Wiederaufnahme der Sitzungsperiode Wiederaufnahme der Sitzungsperiode Resumption of the session 0 34 0 25
1 I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period. Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten. Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten. I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period. 0 218 0 207
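As an optional sanity check (not part of the original walkthrough), you can re-tokenize the processed source texts and confirm they respect the max_input_tok_length passed to the preprocessor:

# re-tokenize each processed German text; the longest should be at (or just under) the 128-token limit
max_proc_len = proc_df["proc_de"].apply(lambda txt: len(hf_tokenizer.tokenize(txt))).max()
print(max_proc_len)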

Examples

Using the mid-level API

Batch-Time Tokenization

Step 1: Get your Hugging Face objects.
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
Step 2: Create your DataBlock

Two lines! Notice we pass in noop for our targets (e.g., the translated texts) because the batch transform will take care of both our inputs and targets.

blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("de"), get_y=ColReader("en"), splitter=RandomSplitter())
# dblock.summary(wmt_df)
Step 3: Build your DataLoaders
dls = dblock.dataloaders(wmt_df, bs=4)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[0]["labels"].shape, b[1].shape
(2, torch.Size([4, 483]), torch.Size([4, 86]), torch.Size([4, 86]))
b[0]["labels"][0], b[1][0]
(tensor([    0,  2223,    13,  1402,  4723,    89,   189,    28, 16855,   111,
            38,   524,  2053,  4010,     9,     5,  2788,  4755,  1293,     6,
           147,     5,  1492,     9,  9813,   696,  4685,   372,  2212,   111,
             5,  3038,    40,    28, 10142,    13,   258,     5,   796,  1332,
             8,  1625,     4,   286,     5,   796,  1332,     6,   142,     5,
          7147,     9,    10,   481,   721,   443,    40,  3155,    24,     7,
          9648,     5,  2621, 10153,   532,    56,    11,  4938,  1048,   137,
             5, 13783,  1288,   376,    88,  1370,     6,  3329,    92,  2919,
          1616,    13,   796,   451,     4,     2], device='cuda:1'),
 tensor([    0,  2223,    13,  1402,  4723,    89,   189,    28, 16855,   111,
            38,   524,  2053,  4010,     9,     5,  2788,  4755,  1293,     6,
           147,     5,  1492,     9,  9813,   696,  4685,   372,  2212,   111,
             5,  3038,    40,    28, 10142,    13,   258,     5,   796,  1332,
             8,  1625,     4,   286,     5,   796,  1332,     6,   142,     5,
          7147,     9,    10,   481,   721,   443,    40,  3155,    24,     7,
          9648,     5,  2621, 10153,   532,    56,    11,  4938,  1048,   137,
             5, 13783,  1288,   376,    88,  1370,     6,  3329,    92,  2919,
          1616,    13,   796,   451,     4,     2], device='cuda:1'))
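As the tensors above show, the labels in the inputs dictionary mirror the targets in b[1]. If you would rather eyeball a target as text than as token ids, you can decode it (an optional check; any negative ignore ids that padding might introduce are filtered out first):

# decode the first target sequence in the batch back to readable English
target_ids = [tok_id for tok_id in b[1][0].tolist() if tok_id >= 0]
print(hf_tokenizer.decode(target_ids, skip_special_tokens=True))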
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
text target
0 <s> Was nun die Ergebnisse der Verhandlungen über die Anwendung der Artikel 3, 4, 5, 6 und 12 des Interimsabkommens bezüglich Warenhandel, öffentlicher Aufträge, Wettbewerb, Konsultationsmechanismen bei Fragen des geistigen Eigentums und Beilegung vo Although for certain sectors there may be flaws - I am thinking specifically of the textiles sector, where the rules of origin issue causes great concern - the effects will be beneficial for both the European Union and Mexico. For the European Union
1 <s> Die allgemeine Ausrichtung der umgesetzten Wirtschaftspolitik, der Stabilitätspakt sowie die strengen Konvergenzprogramme, die der Beschäftigung empfindlich schaden und Beschäftigungsfähigkeit sowie Flexibilität der Arbeitsverhältnisse und Arbeit The general lines of the implemented economic policy, the Stability Pact and the strict convergence programmes, which are a constant menace to employment and which promote employability and the flexibilisation of labour relations, and the organisati

Using a preprocessed dataset

Step 1a: Get your Hugging Face objects.
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
Step 1b. Preprocess dataset
preprocessor = TranslationPreprocessor(
    hf_tokenizer, text_attr="de", target_text_attr="en", max_input_tok_length=128, max_target_tok_length=128
)
proc_df = preprocessor.process_df(wmt_df)
Step 2: Create your DataBlock
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("proc_de"), get_y=ColReader("proc_en"), splitter=RandomSplitter())
Step 3: Build your DataLoaders
dls = dblock.dataloaders(proc_df, bs=4)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[0]["labels"].shape, b[1].shape
(2, torch.Size([4, 129]), torch.Size([4, 83]), torch.Size([4, 83]))
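Because the preprocessor already truncated the source and target texts to roughly the 128-token limits we specified, the padded batch is much shorter here (inputs of shape [4, 129] vs. [4, 483] in the batch-time tokenization example above).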
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
text target
0 <s> Hierbei denke ich an systematische Dokumentation und Informationsbeschaffung, professionelle Formen der Beobachtung, die Entwicklung von Aufklärungsaktionen, die Verwendung von Geldern zur Unterstützung der demokratischen Kräfte in dem betreffend The following spring to mind in this respect: the systematic collation of documentation and information, professional forms of observation, the development of information campaigns, the use of cash to support democratic forces in the country concern
1 <s> Eine letzte Bemerkung: Die durch die Struktur des Internet bedingte permanente Verfügbarkeit von Sexuellem im ausschließlich anonymisierten privaten Bereich und die Tatsache, daß der sexuelle Mißbrauch der öffentlichen und damit der sozialen Kont One final comment: the permanent availability of sexual material in the exclusively anonymous private sphere, which is conditioned by the structure of the Internet, and the attendant fact that sexual abuse is removed from public and, hence, social c

Tests

The purpose of the following tests is to ensure, as much as possible, that the core DataBlock code above works with the pretrained translation models below. These tests are excluded from the CI workflow because of how long they take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained translation models you are working with … and if any of your pretrained translation models fail, please submit a GitHub issue (or a PR if you’d like to fix it yourself).

[model_type for model_type in NLP.get_models(task="ConditionalGeneration") if (not model_type.startswith("TF"))]
['BartForConditionalGeneration',
 'BigBirdPegasusForConditionalGeneration',
 'BlenderbotForConditionalGeneration',
 'BlenderbotSmallForConditionalGeneration',
 'FSMTForConditionalGeneration',
 'LEDForConditionalGeneration',
 'M2M100ForConditionalGeneration',
 'MBartForConditionalGeneration',
 'MT5ForConditionalGeneration',
 'PegasusForConditionalGeneration',
 'ProphetNetForConditionalGeneration',
 'Speech2TextForConditionalGeneration',
 'T5ForConditionalGeneration',
 'XLMProphetNetForConditionalGeneration']
pretrained_model_names = [
    "facebook/bart-base",
    "facebook/wmt19-de-en",  # FSMT
    "Helsinki-NLP/opus-mt-de-en",  # MarianMT
    "sshleifer/tiny-mbart",
    "google/mt5-small",
    "t5-small",
]
path = Path("./")
wmt_df = pd.DataFrame(raw_dataset["translation"], columns=["de", "en"])
model_cls = AutoModelForSeq2SeqLM
bsz = 2
seq_sz = 128
trg_seq_sz = 128

test_results = []
for model_name in pretrained_model_names:
    error = None

    print(f"=== {model_name} ===\n")

    hf_tok_kwargs = {}
    if model_name == "sshleifer/tiny-mbart":
        hf_tok_kwargs["src_lang"], hf_tok_kwargs["tgt_lang"] = "de_DE", "en_XX"

    hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(model_name, model_cls=model_cls, tokenizer_kwargs=hf_tok_kwargs)

    print(f"architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n")

    # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.add_special_tokens({"pad_token": "<pad>"})
        hf_config.pad_token_id = hf_tokenizer.get_vocab()["<pad>"]
        hf_model.resize_token_embeddings(len(hf_tokenizer))

    batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
        hf_arch, hf_config, hf_tokenizer, hf_model, padding="max_length", max_length=seq_sz, max_target_length=trg_seq_sz
    )

    def add_t5_prefix(inp):
        return f"translate German to English: {inp}" if (hf_arch == "t5") else inp

    blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
    dblock = DataBlock(blocks=blocks, get_x=Pipeline([ColReader("de"), add_t5_prefix]), get_y=ColReader("en"), splitter=RandomSplitter())

    dls = dblock.dataloaders(wmt_df, bs=bsz)
    b = dls.one_batch()

    try:
        print("*** TESTING DataLoaders ***\n")
        test_eq(len(b), 2)
        test_eq(len(b[0]["input_ids"]), bsz)
        test_eq(b[0]["input_ids"].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz]))

        if hasattr(hf_tokenizer, "add_prefix_space"):
            test_eq(hf_tokenizer.add_prefix_space, True)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "PASSED", ""))
        dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000)

    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "FAILED", err))
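The summary table below is presumably built from test_results once the loop completes; a minimal sketch (column names inferred from the output):

pd.DataFrame(test_results, columns=["arch", "tokenizer", "model_name", "result", "error"])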
arch tokenizer model_name result error
0 bart BartTokenizerFast facebook/bart-base PASSED
1 fsmt FSMTTokenizer facebook/wmt19-de-en PASSED
2 marian MarianTokenizer Helsinki-NLP/opus-mt-de-en PASSED
3 mbart MBartTokenizerFast sshleifer/tiny-mbart PASSED
4 mt5 T5TokenizerFast google/mt5-small PASSED
5 t5 T5TokenizerFast t5-small PASSED