blurr
Modeling

The text.modeling.seq2seq.translation module contains custom models, custom splitters, etc. for translation tasks.

Mid-level API

Prepare the data

Example

The objective in translation is to generate a representation of a given text in another language or style. For example, we may want to translate German into English, or modern English into Old English.

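from datasets import load_dataset
import pandas as pd

dataset = load_dataset("wmt16", "de-en", split="train")
# keep a small, reproducible sample of 1,200 sentence pairs
dataset = dataset.shuffle(seed=32).select(range(1200))
wmt_df = pd.DataFrame(dataset["translation"], columns=["de", "en"])
len(wmt_df)
wmt_df.head(2)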
de en
0 Tada se dio stanovništva preselio uz samu obalu - Pristan, gdje je i nastao Novi grad početkom XX vijeka. In that period the majority of the population moved close to the seaside, where the first sea port was founded at the beginning of the 20th century, and later a new city was built.
1 "Dieses Video ist nicht verfügbar loger" bitch, daß das Böse, der sein Video auf YouTube hochgeladen hatte nearsyx? "This video is no loger available" that evil bitch, who had uploaded his video on youtube nearsyx?
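Each wmt16 record stores its sentence pair under a "translation" key as a dict keyed by language code, which is why the DataFrame above can be built directly from `dataset["translation"]` with `columns=["de", "en"]`. The stdlib-only sketch below mirrors that reshaping, assuming records shaped like `{"de": ..., "en": ...}`. The `sample_pairs` helper is hypothetical, and its seeded shuffle will not reproduce the exact rows that `Dataset.shuffle(seed=32)` selects.

```python
import random

def sample_pairs(records, n, seed, src="de", trg="en"):
    """Shuffle translation records with a fixed seed, keep the first n,
    and split them into parallel source/target lists (hypothetical helper)."""
    rng = random.Random(seed)   # local RNG so global random state is untouched
    shuffled = list(records)    # copy so the caller's list is not mutated
    rng.shuffle(shuffled)
    subset = shuffled[:n]
    sources = [r[src] for r in subset]
    targets = [r[trg] for r in subset]
    return sources, targets

records = [
    {"de": "Ich trinke gerne Bier", "en": "I like to drink beer"},
    {"de": "Guten Morgen", "en": "Good morning"},
    {"de": "Danke schön", "en": "Thank you very much"},
]
src, trg = sample_pairs(records, n=2, seed=32)
```

Keeping sources and targets as parallel lists preserves the pairing by index, which is the same alignment the `de`/`en` DataFrame columns provide.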
from transformers import AutoModelForSeq2SeqLM

pretrained_model_name = "Helsinki-NLP/opus-mt-de-en"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('marian',
 transformers.models.marian.tokenization_marian.MarianTokenizer,
 transformers.models.marian.configuration_marian.MarianConfig,
 transformers.models.marian.modeling_marian.MarianMTModel)
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("de"), get_y=ColReader("en"), splitter=RandomSplitter())
dls = dblock.dataloaders(wmt_df, bs=2)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[1].shape
(2, torch.Size([2, 168]), torch.Size([2, 140]))
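The batch shapes `[2, 168]` and `[2, 140]` show that inputs and targets are padded independently, each to the longest sequence in its own batch. Below is a minimal sketch of that kind of dynamic padding, assuming right padding with the `pad_token_id` of 58100 from the Marian config; in blurr the Hugging Face tokenizer does this work, and `pad_batch` is a hypothetical helper with made-up token ids.

```python
PAD_ID = 58100  # pad_token_id from the Helsinki-NLP/opus-mt-de-en config

def pad_batch(seqs, pad_id=PAD_ID):
    """Right-pad token id lists to the longest one in the batch and
    build the matching attention mask (1 = real token, 0 = padding)."""
    max_len = max(len(s) for s in seqs)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return input_ids, attention_mask

# made-up token ids for two sequences of different lengths
ids, mask = pad_batch([[12, 7, 0], [5, 0]])
```

Padding per batch rather than to a global maximum length keeps short batches small, which is why the two tensors above end up with different lengths.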
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
text target
0 "In▁Erwägung▁nachstehender▁Gründe▁sollte das▁Europäische▁Parlament▁keinerlei▁Doppelmoral tolerieren. Indessen und um▁politischen Druck auf▁Journalisten▁auszuüben, die▁Korruptionsfälle aufdecken, die in▁Verbindung mit▁hochrangigen▁Beamten und▁regieren 'whereas the European Parliament shall not accept double standards; whereas, in order to put political pressure on journalists disclosing corruption cases linked to high-ranking officials and ruling party politicians, the Government administration in
1 Es▁ist▁jetzt▁wirklich an der Zeit,▁daß nicht▁nur in▁bezug auf den▁Jahreswirtschaftsbericht und die▁wirtschaftspolitischen▁Leitlinien,▁nein,▁auch in▁bezug auf die▁gesamten▁Fragen zum▁Verfahren zur▁Feststellung des▁übermäßigen▁Defizits und▁auch in▁bezu It really is time for the European Parliament to be given a codecision right that is consistent with the further democratic development of this European Union; that right must apply not just to the annual economic report and the economic policy guide

Training

seq2seq_metrics = {"bleu": {"returns": "bleu"}, "meteor": {"returns": "meteor"}, "sacrebleu": {"returns": "score"}}

model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(
    dls,
    model,
    opt_func=partial(Adam),
    loss_func=PreCalculatedCrossEntropyLoss(),  # CrossEntropyLossFlat()
    cbs=learn_cbs,
    splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
)

# learn = learn.to_native_fp16() #.to_fp16()
learn.freeze()
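In `seq2seq_metrics`, each key names a metric and its `"returns"` entry names the field of that metric's result dict to report: Hugging Face's `bleu` metric reports under `"bleu"`, `meteor` under `"meteor"`, and `sacrebleu` under `"score"`. The sketch below shows that lookup with a hypothetical `select_scores` helper and made-up metric outputs. It is not `Seq2SeqMetricsCallback`'s actual code; it only assumes, as the training table below suggests, that each reported column is named after the metric.

```python
seq2seq_metrics = {"bleu": {"returns": "bleu"}, "meteor": {"returns": "meteor"}, "sacrebleu": {"returns": "score"}}

def select_scores(raw_results, metrics_cfg):
    """Report one value per metric, read from the key named in "returns"
    (hypothetical helper; column name = metric name)."""
    return {name: raw_results[name][cfg["returns"]] for name, cfg in metrics_cfg.items()}

# made-up raw outputs shaped like each metric's compute() result
raw = {
    "bleu": {"bleu": 0.30, "precisions": [0.6, 0.4, 0.3, 0.2]},
    "meteor": {"meteor": 0.54},
    "sacrebleu": {"score": 28.9, "counts": [10, 6, 4, 2]},
}
scores = select_scores(raw, seq2seq_metrics)
```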
learn.summary()
b = dls.one_batch()
preds = learn.model(b[0])

len(preds), preds["loss"].shape, preds["logits"].shape
(3, torch.Size([]), torch.Size([2, 140, 58101]))
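The forward pass returns a scalar `loss` alongside `logits` of shape `[batch, target_len, vocab_size]` (58101 here) because the Hugging Face model computes the loss itself when labels are passed in; that is presumably why the Learner uses `PreCalculatedCrossEntropyLoss` rather than recomputing it. The pure-Python sketch below shows the token-level cross-entropy such models compute, assuming padded target positions are marked with `-100`, the default `ignore_index` of PyTorch's `CrossEntropyLoss`. The helper name is hypothetical.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def token_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over target positions whose label != ignore_index.

    logits: list of per-position score lists (length = vocab size)
    labels: list of target token ids, one per position
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padded positions contribute nothing to the loss
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]  # equals -log(softmax(scores)[label])
        count += 1
    return total / count

# second position is padding, so only the first contributes
loss = token_cross_entropy([[0.0, 0.0], [5.0, 1.0]], [0, -100])
```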
len(b), len(b[0]), b[0]["input_ids"].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 168]), 2, torch.Size([2, 140]))
print(len(learn.opt.param_groups))
3
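`blurr_seq2seq_splitter` produces three parameter groups, and `learn.freeze()` was called before training. In fastai, `freeze()` is `freeze_to(-1)`: every group before the last one stops receiving gradient updates. The sketch below models only that bookkeeping as a list of trainable flags (ignoring fastai's extra handling of batchnorm layers via `train_bn`); the function names echo fastai's, but this is a simplified stand-in, not fastai's implementation.

```python
def freeze_to(n, num_groups):
    """Trainable flag per parameter group: groups before index n are frozen.
    Negative n counts from the end, as in fastai's freeze_to."""
    frozen_idx = n if n >= 0 else num_groups + n
    return [i >= frozen_idx for i in range(num_groups)]

def freeze(num_groups):
    # freeze() is freeze_to(-1): only the last group keeps training
    return freeze_to(-1, num_groups)

flags = freeze(3)  # the three groups created by the seq2seq splitter
```

`freeze_to(0, 3)` leaves every group trainable, which corresponds to `learn.unfreeze()`.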
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=3.981071640737355e-05, steep=6.309573450380412e-07, valley=5.248074739938602e-05, slide=7.585775892948732e-05)

learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch train_loss valid_loss bleu meteor sacrebleu time
0 2.088453 2.097524 0.295524 0.543777 28.882930 00:58
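`lr_max=4e-5` sits between the `minimum` (about 3.98e-5) and `valley` (about 5.25e-5) suggestions from `lr_find`. `fit_one_cycle` does not hold that rate constant: with fastai's defaults (`div=25`, `div_final=1e5`, `pct_start=0.25`) it warms up from `lr_max/div` to `lr_max` over the first quarter of training, then cosine-anneals down to `lr_max/div_final`. The sketch below reproduces that schedule shape in plain Python; it is a re-derivation for illustration, not fastai's scheduler object.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1]:
    warm up to lr_max, then anneal down to lr_max / div_final."""
    if pos < pct_start:
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

for pos in (0.0, 0.25, 0.5, 1.0):
    print(f"progress {pos:.2f}: lr = {one_cycle_lr(pos, 4e-5):.3e}")
```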

Showing results

Here we create a @typedispatch implementation of Learner.show_results.

learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=500)
text target prediction
0 ▁Schließen die▁vorgeschlagenen▁Anwendungszwecke▁Empfehlungen über die▁Bekämpfung von oder den▁Schutz▁gegen▁Organismen ein, die▁unter den in der▁vorgesehenen▁Anwendungsregion▁herrschenden▁Bedingungen in▁bezug auf▁Landwirtschaft,▁Pflanzenschutz und Umwelt -▁einschließlich der▁Witterungsverhältnisse - nach den▁Erfahrungen und dem▁wissenschaftlichen▁Erkenntnisstand nicht▁als▁schädlich▁gelten, oder▁ist▁davon▁auszugehen,▁daß die▁anderen▁Wirkungen▁unter▁diesen▁Bedingungen den▁beabsichtigten▁Zweck nicht Where relevant, yield response when the product is used and reduction of loss in storage must be quantitatively and/or qualitatively similar to those resulting from the use of suitable reference products. If no suitable reference product exists, the plant protection product must be shown to give a consistent and defined quantitative and/or qualitative benefit in terms of yield response and reduction of loss in storage under the agricultural, plant health and environmental (including climatic) co [Where the proposed uses include recommendations on the control of or protection against organisms which are not considered to be harmful under the conditions prevailing in the intended application region in respect of agriculture, plant health and the environment, including climatic conditions, in the light of experience and scientific knowledge, or where it is assumed that the other effects do not meet the intended purpose under such conditions, no authorisation shall be granted for such uses., That is why we have listened to you and asked you to introduce a further transparent consultation procedure on the Anti-Counterfeiting Agreement (ACTA) to ensure that the European Parliament and the citizens represented by this Parliament are regularly and comprehensively informed about the progress of the negotiations, while respecting the confidentiality clauses that you have just explained to us about the agreement.]
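fastcore's `@typedispatch` picks a function implementation based on the types of its arguments, which is how a translation-specific `show_results` can coexist with fastai's generic one. The standard library's `functools.singledispatch` offers a rough analog that dispatches on the first argument only (fastcore's version can also dispatch on the second). In the sketch below, `TextInput` is a hypothetical marker type, not a blurr class.

```python
from functools import singledispatch

class TextInput(str):
    """Hypothetical marker type standing in for a seq2seq text batch."""

@singledispatch
def show_results(x, *args, **kwargs):
    # fallback used for any type without a more specific registration
    return "generic display"

@show_results.register
def _(x: TextInput, *args, **kwargs):
    # chosen automatically when x is a TextInput
    return f"seq2seq display: {x}"
```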

Prediction

Here we add a Learner.blurr_translate method that brings the results in line with the format returned by Hugging Face's pipeline method.

test_de = "Ich trinke gerne Bier"
outputs = learn.blurr_generate(test_de, key="translation_texts", num_return_sequences=3)
outputs
[{'translation_texts': ['I like to drink beer',
   'I like to drink beer.',
   'I like drinking beer']}]


Learner.blurr_translate

 Learner.blurr_translate (inp, **kwargs)
learn.blurr_translate(test_de, num_return_sequences=3)
[{'translation_texts': ['I like to drink beer',
   'I like to drink beer.',
   'I like drinking beer']}]
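When `num_return_sequences` is greater than 1, generation yields several candidates per input, grouped consecutively, and the outputs in this section wrap them as one dict per input: a list of strings when several sequences are requested, and a plain string when only one is (as in the `load_learner` output further down). The sketch below reproduces those observed shapes. `to_pipeline_format` is a hypothetical helper, not blurr's internal code, and the `key` argument mirrors the `key="translation_texts"` option passed to `blurr_generate`.

```python
def to_pipeline_format(generated, num_return_sequences=1, key="translation_texts"):
    """Group a flat list of decoded strings into one dict per input.

    Candidates for each input are assumed to be contiguous in `generated`.
    A single candidate is returned as a string, several as a list.
    """
    n = num_return_sequences
    results = []
    for i in range(0, len(generated), n):
        chunk = generated[i:i + n]
        results.append({key: chunk[0] if n == 1 else chunk})
    return results

beams = ["I like to drink beer", "I like to drink beer.", "I like drinking beer"]
formatted = to_pipeline_format(beams, num_return_sequences=3)
```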

Inference

Using fast.ai Learner.export and load_learner

export_fname = "translation_export"
learn.metrics = None
learn.export(fname=f"{export_fname}.pkl")
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_translate(test_de)
[{'translation_texts': 'I like to drink beer'}]
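The source does not say why `learn.metrics` is cleared before exporting. One general constraint worth knowing is that fastai's `Learner.export` pickles the learner, so every object still attached to it must be picklable. The stdlib sketch below illustrates that constraint with a plain dict standing in for learner state (a hypothetical example, not blurr's code): a lambda attribute blocks pickling, and dropping it lets the round trip succeed.

```python
import pickle

# plain dict standing in for a learner's state (hypothetical)
state = {"weights": [0.1, 0.2], "metrics": lambda preds, targs: 0.0}

try:
    pickle.dumps(state)
    exported = True
except (pickle.PicklingError, AttributeError, TypeError):
    exported = False  # lambdas are pickled by name, and "<lambda>" cannot be looked up

state["metrics"] = None  # drop what inference does not need
restored = pickle.loads(pickle.dumps(state))
```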

High-level API



BlearnerForTranslation

 BlearnerForTranslation (dls:fastai.data.core.DataLoaders,
                         hf_model:transformers.modeling_utils.PreTrainedModel,
                         base_model_cb:blurr.text.modeling.core.BaseModelCallback=<class 'blurr.text.modeling.core.BaseModelCallback'>,
                         loss_func:callable|None=None, opt_func=<function Adam>,
                         lr=0.001, splitter:callable=<function trainable_params>,
                         cbs=None, metrics=None, path=None, model_dir='models',
                         wd=None, wd_bn_bias=False, train_bn=True,
                         moms=(0.95, 0.85, 0.95), default_cbs:bool=True)

Group together a model, some dls and a loss_func to handle training

Parameter   Type              Details
dls         DataLoaders       containing data for each dataset needed for model
hf_model    PreTrainedModel
learn = BlearnerForTranslation.from_data(
    wmt_df,
    "Helsinki-NLP/opus-mt-de-en",
    src_lang_name="German",
    src_lang_attr="de",
    trg_lang_name="English",
    trg_lang_attr="en",
    dl_kwargs={"bs": 2},
)
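`from_data` receives both a column name (`src_lang_attr="de"`, `trg_lang_attr="en"`) and a human-readable language name (`src_lang_name="German"`, `trg_lang_name="English"`) for each side. The sketch below shows one plausible use of each: the attributes select the source and target text from every row, and the names could build a T5-style task prefix such as "translate German to English: ". How blurr actually uses the language names is an assumption here, and `build_pairs` is a hypothetical helper; Marian models such as opus-mt-de-en do not need a prefix.

```python
def build_pairs(rows, src_lang_attr, trg_lang_attr, task_prefix=None):
    """Turn row dicts into (source, target) pairs, optionally prefixing the
    source with a T5-style task string (hypothetical helper)."""
    pairs = []
    for row in rows:
        src = row[src_lang_attr]
        if task_prefix:
            src = f"{task_prefix}{src}"
        pairs.append((src, row[trg_lang_attr]))
    return pairs

rows = [{"de": "Ich trinke gerne Bier", "en": "I like to drink beer"}]
src_lang_name, trg_lang_name = "German", "English"
prefix = f"translate {src_lang_name} to {trg_lang_name}: "  # assumed prefix format
pairs = build_pairs(rows, "de", "en", task_prefix=prefix)
```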
metrics_cb = BlearnerForTranslation.get_metrics_cb()
learn.fit_one_cycle(1, lr_max=4e-5, cbs=[metrics_cb])
epoch train_loss valid_loss bleu meteor sacrebleu time
0 2.014099 2.172663 0.307334 0.537191 29.823626 00:52
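The `bleu` and `sacrebleu` columns both report BLEU, on different scales: `bleu` is in [0, 1] (0.307 here) while `sacrebleu` is on a 0 to 100 scale (29.82 here). Implementation details such as tokenization differ between the two, so one is not exactly 100 times the other. The sketch below is a simplified corpus BLEU (clipped n-gram precisions up to 4-grams, geometric mean, brevity penalty, one reference per hypothesis, no smoothing) meant to show what the number measures. It is not the exact code behind either metric.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def corpus_bleu(hyps, refs, max_n=4):
    """Simplified corpus BLEU in [0, 1]; hyps and refs are token lists."""
    matches, totals = [0] * max_n, [0] * max_n
    hyp_len = ref_len = 0
    for h, r in zip(hyps, refs):
        hyp_len += len(h)
        ref_len += len(r)
        for n in range(1, max_n + 1):
            h_counts, r_counts = ngrams(h, n), ngrams(r, n)
            # clip each hypothesis n-gram count by its count in the reference
            matches[n - 1] += sum(min(c, r_counts[g]) for g, c in h_counts.items())
            totals[n - 1] += max(len(h) - n + 1, 0)
    if min(matches) == 0:
        return 0.0  # without smoothing, any zero precision gives BLEU = 0
    log_prec = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # brevity penalty: punish hypotheses shorter than their references
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return bp * math.exp(log_prec)

ref = "the cat sat on the mat".split()
hyp = "the cat sat on the".split()
score = corpus_bleu([hyp], [ref])
sacrebleu_scale = 100 * score  # the same quantity on a 0 to 100 scale
```

Here every n-gram of the hypothesis matches, so the score comes entirely from the brevity penalty, `exp(1 - 6/5)`.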
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
(row 0)
text: (IT) Herr Präsident, Herr Kommissar, meine Damen und Herren, so genau wie die Entschließung mit dem Titel "Naturkatastrophen", die von der Fraktion der Europäischen Volkspartei (Christdemokraten) vorgelegt wurde, ist, würde ich gerne trotzdem die Aufmerksamkeit auf einige Punkte lenken, die heute Abend angesprochen wurden, die aber nicht in der Entschließung zum Thema gemacht werden, und die Gegenstand meiner Änderungsvorschläge sind.
target: (IT) Mr President, Commissioner, ladies and gentlemen, as accurate as the resolution entitled 'Natural disasters', tabled by the Group of the European People's Party (Christian Democrats), is, I would nonetheless like to draw attention to some points
prediction: [(IT) Mr President, Commissioner, ladies and gentlemen, just as the resolution on natural disasters presented by the Group of the European People's Party (Christian Democrats), I would like to draw attention to some of the points raised this evening, which are not dealt with in the resolution, and which are the subject of my amendments., This Parliament has always been an example and a champion in the defence of human rights, and at this critical time it must prove that it is not doing a common cause with a corrupt dictator in full decline and that it will allow itself to be carried away by the collaboration of some Members who have always been manipulated by this dictatorship.]
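Marian checkpoints such as Helsinki-NLP/opus-mt-de-en use SentencePiece tokenizers, which mark the start of each word with the ▁ character (U+2581). If those markers show up in displayed text in place of spaces, the pieces were joined without being converted back. A minimal sketch of that conversion, assuming plain SentencePiece pieces with no byte fallback or special tokens; in practice the tokenizer's own decode method is the safer choice. The helper name and the piece boundaries in the usage line are made up for illustration.

```python
SPIECE_UNDERLINE = "\u2581"  # "▁", SentencePiece's word-boundary marker


def detokenize_pieces(pieces):
    """Hypothetical helper: join SentencePiece pieces and turn each ▁ marker back into a space."""
    return "".join(pieces).replace(SPIECE_UNDERLINE, " ").strip()


# illustrative piece split, not necessarily what the Marian tokenizer produces
print(detokenize_pieces(["▁Ich", "▁trink", "e", "▁gerne", "▁Bier"]))  # Ich trinke gerne Bier
```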
test_de = "Ich trinke gerne Bier"
learn.blurr_translate(test_de)
[{'translation_texts': 'I like to drink beer'}]
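export_fname = "translation_export"

# metrics are only used during training and validation, so drop them before exporting
learn.metrics = None
# remove mixed precision so the exported Learner runs in full (fp32) precision
learn = learn.to_fp32()
learn.export(fname=f"{export_fname}.pkl")

# reload the exported Learner and run generation on the same German sentence
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_generate(test_de)
[{'generated_texts': 'I like to drink beer'}]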

Tests

The purpose of the following tests is to ensure, as much as possible, that the core training code works for the pretrained translation models below. These tests are excluded from the CI workflow because they take a long time to run and require downloading a large amount of data.

Note: Feel free to modify the code below to test whatever pretrained translation models you are working with. If any of them fail, please submit a GitHub issue (or a PR if you’d like to fix it yourself).

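try:
    # free the Learner from the previous section (if any) and release cached GPU memory
    del learn
    torch.cuda.empty_cache()
except NameError:
    # `learn` was never defined, so there is nothing to clean up
    pass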
[model_type for model_type in NLP.get_models(task="ConditionalGeneration") if (not model_type.startswith("TF"))]
['BartForConditionalGeneration',
 'BigBirdPegasusForConditionalGeneration',
 'BlenderbotForConditionalGeneration',
 'BlenderbotSmallForConditionalGeneration',
 'FSMTForConditionalGeneration',
 'LEDForConditionalGeneration',
 'M2M100ForConditionalGeneration',
 'MBartForConditionalGeneration',
 'MT5ForConditionalGeneration',
 'PLBartForConditionalGeneration',
 'PegasusForConditionalGeneration',
 'ProphetNetForConditionalGeneration',
 'Speech2TextForConditionalGeneration',
 'T5ForConditionalGeneration',
 'XLMProphetNetForConditionalGeneration']
pretrained_model_names = [
    "facebook/bart-base",
    "facebook/wmt19-de-en",  # FSMT
    "Helsinki-NLP/opus-mt-de-en",  # MarianMT
    #'sshleifer/tiny-mbart',
    #'google/mt5-small',
    "t5-small",
]
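Two of the checkpoints handled by the test loop need model-specific preprocessing: the mBART tokenizer takes source and target language codes (de_DE and en_XX here), and T5 expects a task prefix on every input. The loop handles both inline; as a sketch, the same two rules can be pulled into small helpers. The helper names are hypothetical and not part of blurr.

```python
from typing import Dict

T5_PREFIX = "translate German to English: "


def tokenizer_kwargs_for(model_name: str) -> Dict[str, str]:
    """Hypothetical helper: extra tokenizer kwargs (mirrors the mBART special case in the test loop)."""
    if model_name == "sshleifer/tiny-mbart":
        return {"src_lang": "de_DE", "tgt_lang": "en_XX"}
    return {}


def prepare_input(hf_arch: str, text: str) -> str:
    """Hypothetical helper: prepend T5's task prefix; other architectures get the text unchanged."""
    return f"{T5_PREFIX}{text}" if hf_arch == "t5" else text


print(prepare_input("t5", "Ich trinke gerne Bier"))  # translate German to English: Ich trinke gerne Bier
print(prepare_input("marian", "Ich trinke gerne Bier"))  # Ich trinke gerne Bier
```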
dataset = load_dataset("wmt16", "de-en", split="train")
dataset = dataset.shuffle(seed=32).select(range(1200))
wmt_df = pd.DataFrame(dataset["translation"], columns=["de", "en"])
len(wmt_df)
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f)
Loading cached shuffled indices for dataset at /home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f/cache-8fc54b133c8c43b7.arrow
1200
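The DataBlock in the test loop uses fastai's RandomSplitter() with its defaults, which hold out 20% of the items (valid_pct=0.2) for validation, so these 1,200 rows become 960 training and 240 validation examples. A stdlib sketch of an equivalent index split; fastai shuffles with torch.randperm rather than Python's random module, so the actual indices will differ, and the function name is hypothetical.

```python
import random


def random_split(n_items, valid_pct=0.2, seed=None):
    """Shuffle item indices and hold out the first `valid_pct` share as the validation set."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]  # (train indices, valid indices)


train_idxs, valid_idxs = random_split(1200, seed=42)
print(len(train_idxs), len(valid_idxs))  # 960 240
```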
model_cls = AutoModelForSeq2SeqLM
bsz = 2
inp_seq_sz = 128
trg_seq_sz = 128

test_results = []
for model_name in pretrained_model_names:
    error = None

    print(f"=== {model_name} ===\n")

    hf_tok_kwargs = {}
    if model_name == "sshleifer/tiny-mbart":
        hf_tok_kwargs["src_lang"], hf_tok_kwargs["tgt_lang"] = "de_DE", "en_XX"

    hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(model_name, model_cls=model_cls, tokenizer_kwargs=hf_tok_kwargs)

    print(f"architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n")

    # 1. build your DataBlock
    text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task="translation")

    def add_t5_prefix(inp):
        return f"translate German to English: {inp}" if (hf_arch == "t5") else inp

    batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
        hf_arch,
        hf_config,
        hf_tokenizer,
        hf_model,
        padding="max_length",
        max_length=inp_seq_sz,
        max_target_length=trg_seq_sz,
        text_gen_kwargs=text_gen_kwargs,
    )

    blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
    dblock = DataBlock(blocks=blocks, get_x=Pipeline([ColReader("de"), add_t5_prefix]), get_y=ColReader("en"), splitter=RandomSplitter())

    dls = dblock.dataloaders(wmt_df, bs=bsz)
    b = dls.one_batch()

    # 2. build your Learner
    seq2seq_metrics = {}

    model = BaseModelWrapper(hf_model)
    fit_cbs = [ShortEpochCallback(0.05, short_valid=True), Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

    learn = Learner(
        dls,
        model,
        opt_func=ranger,
        loss_func=PreCalculatedCrossEntropyLoss(),
        cbs=[BaseModelCallback],
        splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
    ).to_fp16()

    learn.create_opt()
    learn.freeze()

    # 3. Run your tests
    try:
        print("*** TESTING DataLoaders ***\n")
        test_eq(len(b), 2)
        test_eq(len(b[0]["input_ids"]), bsz)
        test_eq(b[0]["input_ids"].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

        # print('*** TESTING One pass through the model ***')
        # preds = learn.model(b[0])
        # test_eq(preds[1].shape[0], bsz)
        # test_eq(preds[1].shape[2], hf_config.vocab_size)

        print("*** TESTING Training/Results ***")
        learn.fit_one_cycle(1, lr_max=1e-3, cbs=fit_cbs)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, "PASSED", ""))
        learn.show_results(learner=learn, max_n=2, input_trunc_at=500, target_trunc_at=250)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, "FAILED", err))
    finally:
        # cleanup
        del learn
        torch.cuda.empty_cache()
   arch    tokenizer          model_name                    result  error
0  bart    BartTokenizerFast  BartForConditionalGeneration  PASSED
1  fsmt    FSMTTokenizer      FSMTForConditionalGeneration  PASSED
2  marian  MarianTokenizer    MarianMTModel                 PASSED
3  t5      T5TokenizerFast    T5ForConditionalGeneration    PASSED
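Stripped of the fastai and Hugging Face specifics, the test loop follows a record-and-continue pattern: each model runs inside try/except, a failure is stored rather than raised, and the collected tuples are then rendered as the results table. A stdlib sketch of that pattern, where run_one and fake_run are hypothetical stand-ins for building the DataLoaders and Learner and fitting one short epoch (the real loop also frees GPU memory in a finally block, omitted here):

```python
from typing import Callable, Iterable, List, Tuple

Result = Tuple[str, str, str]  # (model name, "PASSED" or "FAILED", error message)


def run_smoke_tests(model_names: Iterable[str], run_one: Callable[[str], None]) -> List[Result]:
    """Run `run_one` for every model, recording failures instead of stopping at the first one."""
    results: List[Result] = []
    for name in model_names:
        try:
            run_one(name)
            results.append((name, "PASSED", ""))
        except Exception as err:  # keep going so one bad checkpoint doesn't hide the rest
            results.append((name, "FAILED", repr(err)))
    return results


def format_results(results: List[Result]) -> str:
    """Render result tuples as a left-aligned plain-text table."""
    header = ("model_name", "result", "error")
    rows = [header] + list(results)
    widths = [max(len(row[i]) for row in rows) for i in range(len(header))]
    return "\n".join("  ".join(cell.ljust(w) for cell, w in zip(row, widths)).rstrip() for row in rows)


def fake_run(name: str) -> None:
    """Hypothetical stand-in for one training run: pretend one checkpoint fails."""
    if name == "broken-model":
        raise ValueError("unsupported architecture")


print(format_results(run_smoke_tests(["t5-small", "broken-model"], fake_run)))
```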