This module contains custom models, custom splitters, etc. for translation tasks.
 
What we're running with at the time this documentation was generated:
torch: 1.9.0+cu102
fastai: 2.5.2
transformers: 4.10.0
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti

Translation

Translation tasks convert text written in one language into another language.

Prepare the data

ds = load_dataset('wmt16', 'de-en', split='train[:1%]')
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a)
path = Path('./')
wmt_df = pd.DataFrame(ds['translation'], columns=['de', 'en']); len(wmt_df)
45489
wmt_df = wmt_df.iloc[:1000]
wmt_df.head(2)
de en
0 Wiederaufnahme der Sitzungsperiode Resumption of the session
1 Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten. I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.
pretrained_model_name = "Helsinki-NLP/opus-mt-de-en"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('marian',
 transformers.models.marian.tokenization_marian.MarianTokenizer,
 transformers.models.marian.configuration_marian.MarianConfig,
 transformers.models.marian.modeling_marian.MarianMTModel)
blocks = (HF_Seq2SeqBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader('de'), get_y=ColReader('en'), splitter=RandomSplitter())
dls = dblock.dataloaders(wmt_df, bs=2)
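RandomSplitter() is used here with its defaults. As we understand fastai's implementation, it shuffles the row indices and holds out valid_pct of them (0.2 by default), so about 200 of the 1,000 rows become the validation set. A simplified sketch of that idea follows; random_split is a hypothetical name, not fastai's code.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=None):
    """Shuffle item indices and hold out `valid_pct` of them for validation.

    Hypothetical sketch of the idea behind fastai's RandomSplitter.
    """
    rng = random.Random(seed)          # local RNG, so global random state is untouched
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    cut = int(valid_pct * n_items)     # number of validation items
    return idxs[cut:], idxs[:cut]      # (train indices, valid indices)

train_idxs, valid_idxs = random_split(1000, seed=42)
print(len(train_idxs), len(valid_idxs))  # 800 200
```

Passing a fixed seed makes the split reproducible across runs.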
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 141]), torch.Size([2, 83]))
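In this batch the inputs are 141 tokens wide and the targets 83. This suggests that each side is padded independently to the longest sequence in the batch (an inference from the shapes; the Tests section below requests padding='max_length' instead). A minimal sketch of pad-to-longest with an attention mask, using made-up token ids; pad_to_longest is hypothetical:

```python
def pad_to_longest(batch, pad_id):
    """Right-pad token-id lists to the longest sequence in the batch.

    Returns the padded ids plus an attention mask (1 = real token, 0 = padding).
    `pad_id` is a parameter here; real tokenizers define their own pad token id.
    """
    max_len = max(len(seq) for seq in batch)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in batch]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in batch]
    return input_ids, attention_mask

ids, mask = pad_to_longest([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [8, 0, 0]]
print(mask)  # [[1, 1, 1], [1, 0, 0]]
```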
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
text target
0 ▁Angesichts▁dieser Situation▁muß▁aus dem▁Bericht, den das▁Parlament annimmt,▁klar▁hervorgehen,▁daß▁Maßnahmen▁notwendig▁sind, die▁eindeutig auf die▁Bekämpfung der relativen▁Armut und der Arbeitslosigkeit▁gerichtet▁sind.▁Maßnahmen▁wie die für diese▁Zwe Given this situation, the report approved by Parliament must highlight the need for measures that aim unequivocally to fight relative poverty and unemployment: measures such as the appropriate use of structural funds for these purposes, which are oft
1 ▁Ich▁kann▁jetzt nicht▁alle▁nennen:▁Einführung von▁sektorübergreifenden▁Maßnahmen,▁effizientere▁Nutzung▁öffentlicher▁Gelder,▁Unterstützung der▁unterschiedlichen Partner▁bei der▁gemeinsamen▁Entwicklung von▁regionalen oder▁nationalen▁Programmen▁usw. Die I shall not list them all, but they include implementing intersectoral policies, increasing efficiency in the use of public funds, assisting the various partners in drawing up regional or national programming together, etc. The Commission takes note
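The `▁` (U+2581) characters in the text column are SentencePiece word-boundary markers. The Marian tokenizer is SentencePiece-based, and in this display the markers were left in the decoded input text instead of being turned back into spaces. The usual decoding rule is to concatenate the pieces and replace each marker with a space. A minimal sketch, where the piece split shown is made up for illustration and is not Marian's actual tokenization:

```python
SP_SPACE = "\u2581"  # "▁", SentencePiece's marker for a preceding space

def detokenize(pieces):
    """Join SentencePiece pieces back into plain text."""
    return "".join(pieces).replace(SP_SPACE, " ").strip()

print(detokenize(["▁Ich", "▁trink", "e", "▁gerne", "▁Bier"]))  # Ich trinke gerne Bier
```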

Train model

seq2seq_metrics = {
    'bleu': { 'returns': "bleu" },
    'meteor': { 'returns': "meteor" },
    'sacrebleu': { 'returns': "score" }
}
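Each entry in seq2seq_metrics names a metric and, under 'returns', the key of that metric's result dict to report. For example, sacrebleu reports its corpus score under 'score', while bleu and meteor use their own names. Assuming the callback computes each metric and then looks up that key, the selection step amounts to the sketch below. select_scores is hypothetical, and the result dicts hold placeholder numbers, not results from this notebook.

```python
# Placeholder outputs shaped like each metric's result dict (values are dummies).
raw_results = {
    "bleu": {"bleu": 0.25, "precisions": [0.5, 0.3, 0.2, 0.1]},
    "meteor": {"meteor": 0.5},
    "sacrebleu": {"score": 25.0, "bp": 1.0},
}

seq2seq_metrics = {
    "bleu": {"returns": "bleu"},
    "meteor": {"returns": "meteor"},
    "sacrebleu": {"returns": "score"},
}

def select_scores(raw_results, metric_specs):
    """Keep only the value named by each metric's 'returns' key."""
    return {name: raw_results[name][spec["returns"]] for name, spec in metric_specs.items()}

print(select_scores(raw_results, seq2seq_metrics))
# {'bleu': 0.25, 'meteor': 0.5, 'sacrebleu': 25.0}
```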

model = HF_BaseModelWrapper(hf_model)

learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(dls, 
                model,
                opt_func=Adam,
                loss_func=CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) #.to_native_fp16() #.to_fp16()

learn.freeze()
[nltk_data] Downloading package wordnet to /home/wgilliam/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
learn.summary()
# preds = learn.model(b[0])

# len(preds),preds['loss'].shape, preds['logits'].shape
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 141]), 2, torch.Size([2, 83]))
print(len(learn.opt.param_groups))
3
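The printout shows that the splitter produced three parameter groups. In fastai, learn.freeze() is freeze_to(-1): every group except the last stops being updated, and the optimizer is created first if it does not exist yet, which is why learn.opt is available here. The sketch below models only that bookkeeping with plain dicts (no PyTorch). The group names are hypothetical and not necessarily how seq2seq_splitter divides the model.

```python
def freeze_to(param_groups, n):
    """Mark groups before index `n` as frozen and the rest as trainable.

    `param_groups` stands in for the optimizer's parameter groups; negative `n`
    counts from the end, like a list index.
    """
    start = n if n >= 0 else len(param_groups) + n
    for i, group in enumerate(param_groups):
        group["trainable"] = i >= start
    return param_groups

# Three groups, as printed above (names are illustrative only).
groups = [{"name": "group_0"}, {"name": "group_1"}, {"name": "group_2"}]
freeze_to(groups, -1)                     # conceptually what learn.freeze() does
print([g["trainable"] for g in groups])   # [False, False, True]
freeze_to(groups, 0)                      # conceptually learn.unfreeze(): all groups train
```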
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=0.00012022644514217973, steep=9.12010818865383e-07, valley=9.120108734350652e-05, slide=4.365158383734524e-05)
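Each suggestion function turns the recorded (learning rate, loss) curve into one candidate value. The run below trains with lr_max=4e-5, which is close to the slide suggestion (about 4.4e-5). Simplified versions of two of the heuristics, as we understand fastai's minimum and steep, are sketched below on a made-up curve. valley and slide use more involved rules and are not reproduced.

```python
import math

def minimum(lrs, losses):
    """Learning rate at the lowest recorded loss, divided by 10."""
    i = min(range(len(losses)), key=losses.__getitem__)
    return lrs[i] / 10

def steep(lrs, losses):
    """Learning rate where the loss drops fastest with respect to log(lr)."""
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    i = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[i]

# Synthetic loss curve, made up for illustration.
lrs = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
losses = [2.0, 1.9, 1.2, 1.0, 3.0]
print(f"minimum={minimum(lrs, losses):.0e}, steep={steep(lrs, losses):.0e}")
# minimum=1e-04, steep=1e-05
```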
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch  train_loss  valid_loss  bleu      meteor    sacrebleu  time
0      1.292981    1.263649    0.320008  0.540354  31.188232  00:59
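fit_one_cycle(1, lr_max=4e-5) does not hold the learning rate at 4e-5. It warms up from lr_max/div to lr_max over the first pct_start of training, then anneals down to lr_max/div_final, following cosine curves in both phases. The sketch uses what we believe are fastai's defaults (div=25, div_final=1e5, pct_start=0.25); check the fit_one_cycle signature in your fastai version. Momentum scheduling is omitted.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pct, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at fraction `pct` (0..1) of training under a one-cycle schedule."""
    lr_start, lr_end = lr_max / div, lr_max / div_final
    if pct < pct_start:  # warm-up phase: lr_start -> lr_max
        return cos_anneal(lr_start, lr_max, pct / pct_start)
    # annealing phase: lr_max -> lr_end
    return cos_anneal(lr_max, lr_end, (pct - pct_start) / (1 - pct_start))

for pct in (0.0, 0.25, 0.5, 1.0):
    print(f"{pct:.2f}: {one_cycle_lr(pct, lr_max=4e-5):.2e}")
```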
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=500)
text target prediction
0 Aus▁diesem Grund▁ist es▁eines der▁wichtigsten und▁weitreichendsten▁Ziele, die wir▁uns in der▁Europäischen Union▁stellen▁sollten,▁Anstrengungen zur▁Schaffung▁neuer▁Arbeitsplätze in den▁ländlichen▁Gebieten▁außerhalb des▁Agrarsektors zu▁unternehmen,▁unter▁anderem in den▁Bereichen▁ländlicher▁Tourismus, Sport, Kultur,▁Sanierung der▁Ressourcen,▁Umstellung von▁Unternehmen,▁neue▁Technologien,▁Dienstleistungen▁usw.▁Doch▁obwohl die▁Landwirtschaft▁keine▁ausschließliche▁Rolle▁mehr▁spielt,▁ist▁sie▁weiterhin▁ For this reason, one of the most important and essential objectives which we should set in the European Union is to make efforts to create new jobs in rural areas, outside of the agricultural sector, in sectors such as rural tourism, sport, culture, heritage conservation, the conversion of businesses, new technologies, services, etc. However, even though the role of agriculture is not exclusive, it is still essential, not only to prevent economic and social disintegration and the creation of gho For this reason, it is one of the most important and far-reaching objectives that we should set ourselves in the European Union to make efforts to create new jobs in rural areas outside the agricultural sector, including in the areas of rural tourism, sport, culture, rehabilitation of resources, conversion of businesses, new technologies, services, etc. However, although agriculture no longer plays an exclusive role, it is still important, not only to prevent the economic and social decline of r
1 Es▁muß▁daran▁erinnert▁werden,▁daß die▁globale▁Wettbewerbsfähigkeit der▁Europäischen Union▁gegenwärtig 81 % des▁Niveaus der▁Vereinigten▁Staaten von▁Amerika▁erreicht und▁daß diese▁Kennziffer sich▁nur▁dann▁verbessern▁wird,▁wenn sich die▁unserer wettbewerbsfähigen▁Wirtschaftseinheiten,▁nämlich der▁Regionen,▁verbessert, und das zu▁einem▁Zeitpunkt, da die▁technologische▁Entwicklung, die▁Globalisierung der Wirtschaft und▁unsere▁Probleme, die▁Erweiterung und die▁Einheitswährung, von den▁Regionen,▁aber▁a It must be remembered that, currently, the European Union' s overall competitiveness is, in general terms, 81% of that of the United States of America and that this figure will only improve if the figure for our competitive units, that is the regions, also improves. Furthermore, this is at a time when technological development, economic globalisation and our problems, which are enlargement and the single currency, demand that the regions, as well as businesses and individuals, make more of an ef It should be recalled that the European Union' s overall competitiveness is currently reaching 81% of the level of the United States of America and that this figure will only improve if our competitive economic units, namely the regions, improve, at a time when technological development, the globalisation of the economy and our problems, enlargement and the single currency, demand greater competition from the regions, but also from businesses and individuals.
test_de = "Ich trinke gerne Bier"
outputs = learn.blurr_generate(test_de, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
I like to drink beer

=== Prediction 2 ===
I like to drink beer.

=== Prediction 3 ===
I like drinking beer
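num_return_sequences=3 asks generation for several candidate translations. Assuming the generation settings use beam search (three near-identical outputs are typical of it), the candidates are the highest-scoring hypotheses on the beam. In Hugging Face generate, num_return_sequences must not exceed num_beams in that mode. The toy beam search below runs over a made-up next-token table. It is a simplified illustration, not the Hugging Face implementation, which among other things keeps num_beams live hypotheses and applies a length penalty.

```python
import math

# Toy next-token probabilities keyed by the tokens generated so far.
# Made up for illustration; a real model scores its whole vocabulary at each step.
TOY_MODEL = {
    (): {"I": 1.0},
    ("I",): {"like": 0.6, "enjoy": 0.4},
    ("I", "like"): {"beer": 0.5, "drinking": 0.5},
    ("I", "enjoy"): {"beer": 1.0},
    ("I", "like", "drinking"): {"beer": 1.0},
}
EOS = "</s>"

def next_probs(prefix):
    """Any prefix not in the table ends the sentence."""
    return TOY_MODEL.get(prefix, {EOS: 1.0})

def beam_search(num_beams=3, num_return_sequences=3, max_len=10):
    beams = [((), 0.0)]  # (tokens so far, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        # Expand every live hypothesis by every possible next token.
        candidates = [
            (tokens + (tok,), score + math.log(p))
            for tokens, score in beams
            for tok, p in next_probs(tokens).items()
        ]
        # Keep the best `num_beams` expansions; completed ones leave the beam.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:num_beams]:
            if tokens[-1] == EOS:
                finished.append((tokens[:-1], score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # keep unfinished hypotheses if max_len was reached
    finished.sort(key=lambda c: c[1], reverse=True)
    return [" ".join(tokens) for tokens, _ in finished[:num_return_sequences]]

for idx, o in enumerate(beam_search()):
    print(f"=== Prediction {idx+1} ===\n{o}\n")
```

With num_beams=1 the search reduces to greedy decoding and returns only "I like beer", missing the higher-scoring "I enjoy beer" that the wider beam finds.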

Inference

export_fname = 'translation_export'
learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_de)
['I like to drink beer']
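learn.export() pickles the learner, so setting learn.metrics = None first keeps training-only state out of the exported file. Metric objects are not needed for inference, and some objects (lambdas, for example) cannot be pickled at all. That this is the motivation here is an assumption. The sketch below shows the general pattern with a plain dict standing in for a learner.

```python
import pickle

# Hypothetical stand-in for a learner: model state plus a training-only metric.
learner_state = {
    "model": {"weights": [0.1, 0.2]},
    "metrics": [lambda preds, targs: 0.0],  # lambdas cannot be pickled
}

try:
    pickle.dumps(learner_state)
    export_failed = False
except (pickle.PicklingError, AttributeError, TypeError):
    export_failed = True

learner_state["metrics"] = None  # drop what inference does not need
restored = pickle.loads(pickle.dumps(learner_state))
print(export_failed, restored["model"])  # True {'weights': [0.1, 0.2]}
```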

High-level API

BlearnerForTranslation

We also introduce a task-specific Blearner that gets you your DataBlock, DataLoaders, and Blearner in one line of code!

class BlearnerForTranslation

BlearnerForTranslation(dls:DataLoaders, hf_model:PreTrainedModel, base_model_cb:HF_BaseModelCallback=HF_BaseModelCallback, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Blearner

Group together a model, some dls and a loss_func to handle training

Parameters:

  • dls : <class 'fastai.data.core.DataLoaders'>

  • hf_model : <class 'transformers.modeling_utils.PreTrainedModel'>

  • kwargs : (no type annotation)

learn = BlearnerForTranslation.from_dataframe(wmt_df, 'Helsinki-NLP/opus-mt-de-en', 
                                              src_lang_name='German', src_lang_attr='de', 
                                              trg_lang_name='English', trg_lang_attr='en', 
                                              dblock_splitter=RandomSplitter(),
                                              dl_kwargs={'bs':2})
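from_dataframe takes column names (src_lang_attr='de', trg_lang_attr='en') where the manual DataBlock above used ColReader('de') and ColReader('en'). The human-readable src_lang_name/trg_lang_name ('German', 'English') presumably feed prompts such as the T5 task prefix built by add_t5_prefix in the Tests section. That is an assumption about BlearnerForTranslation's internals. The hypothetical helper below illustrates the mapping:

```python
def make_example(row, src_lang_attr, trg_lang_attr,
                 src_lang_name=None, trg_lang_name=None, arch=None):
    """Build (input text, target text) from one row.

    Hypothetical helper: reads the source/target columns and, for T5, prepends
    a task prefix in the same format as add_t5_prefix in the Tests section.
    """
    src, trg = row[src_lang_attr], row[trg_lang_attr]
    if arch == "t5":
        src = f"translate {src_lang_name} to {trg_lang_name}: {src}"
    return src, trg

row = {"de": "Ich trinke gerne Bier", "en": "I like to drink beer"}
print(make_example(row, "de", "en", "German", "English", arch="t5"))
# ('translate German to English: Ich trinke gerne Bier', 'I like to drink beer')
```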
metrics_cb = BlearnerForTranslation.get_metrics_cb()
learn.fit_one_cycle(1, lr_max=4e-5, cbs=[metrics_cb])
[nltk_data] Downloading package wordnet to /home/wgilliam/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
epoch  train_loss  valid_loss  bleu      meteor    sacrebleu  time
0      1.326087    1.324222    0.308548  0.511621  30.511772  00:57
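The bleu column is on a 0 to 1 scale while sacrebleu reports 0 to 100, so 0.309 and 30.5 are the same order of magnitude. The two implementations tokenize differently, so the numbers need not agree exactly. The sketch below makes the score concrete with a plain corpus-level BLEU (uniform weights up to 4-grams, one reference per sentence, whitespace tokenization, no smoothing). It is a simplification, not the code behind either metric.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def corpus_bleu(candidates, references, max_n=4):
    """Corpus BLEU with one reference per candidate, whitespace tokens, no smoothing.

    Returns a value on the 0-1 scale (multiply by 100 for sacrebleu-style scores).
    """
    matches = [0] * max_n  # clipped n-gram matches per order
    totals = [0] * max_n   # candidate n-gram counts per order
    cand_len = ref_len = 0
    for cand, ref in zip(candidates, references):
        c, r = cand.split(), ref.split()
        cand_len += len(c)
        ref_len += len(r)
        for n in range(1, max_n + 1):
            c_ngrams, r_ngrams = ngrams(c, n), ngrams(r, n)
            matches[n - 1] += sum(min(cnt, r_ngrams[g]) for g, cnt in c_ngrams.items())
            totals[n - 1] += max(len(c) - n + 1, 0)
    if min(matches) == 0:
        return 0.0  # any zero precision makes the geometric mean zero
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    brevity_penalty = 1.0 if cand_len > ref_len else math.exp(1 - ref_len / cand_len)
    return brevity_penalty * math.exp(log_precision)

print(corpus_bleu(["I like to drink cold beer"], ["I like to drink beer"]))
```

Without smoothing, a short candidate that shares no trigram with its reference scores exactly 0, which is one reason sentence-level BLEU is usually smoothed.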
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 ▁Angesichts▁dieser Situation▁muß▁aus dem▁Bericht, den das▁Parlament annimmt,▁klar▁hervorgehen,▁daß▁Maßnahmen▁notwendig▁sind, die▁eindeutig auf die▁Bekämpfung der relativen▁Armut und der Arbeitslosigkeit▁gerichtet▁sind.▁Maßnahmen▁wie die für diese▁Zwecke▁angemessene▁Verwendung der▁Strukturfonds, die▁häufig▁unsachgemäß▁eingesetzt▁werden, und▁zwar mit▁zentralen▁staatlichen▁Politiken, die▁Modernisierung der▁Bereiche Telekommunikation und▁Kommunikation,▁indem man vor▁allem die am▁wenigsten▁entwickelt Given this situation, the report approved by Parliament must highlight the need for measures that aim unequivocally to fight relative poverty and unemployment: measures such as the appropriate use of structural funds for these purposes, which are oft In view of this situation, the report adopted by Parliament must clearly show the need for measures which are clearly aimed at combating relative poverty and unemployment, such as the use of the Structural Funds, which are often used improperly for t
1 ▁Deshalb▁besteht der▁Vorschlag der▁Fraktion der▁Sozialdemokratischen▁Partei▁Europas, den Sie▁erwähnt▁haben,▁darin, den▁Mittwoch▁als▁Termin der▁Vorstellung des▁Programms der▁Kommission Prodi für die▁Wahlperiode▁beizubehalten, und in▁dieses▁Programm▁auch das▁Verwaltungsreformprojekt▁einzubeziehen, da wir▁andernfalls in eine paradoxe Situation▁geraten▁könnten: Mit der Ausrede, der▁Wortlaut liege nicht vor,▁wird▁einerseits dem▁Präsidenten der▁Kommission das▁Recht▁abgesprochen, in▁diesem▁Parlament zu Therefore, the proposal of the Group of the Party of European Socialists, and which you have mentioned, is that the Prodi Commission present its legislative programme on Wednesday, including its proposed administrative reform, because, otherwise, we That is why the proposal of the Group of the Party of European Socialists, which you have mentioned, is to maintain Wednesday as the date for the presentation of the Prodi Commission' s programme for the parliamentary term, and to include in this pro
test_de = "Ich trinke gerne Bier"
learn.blurr_generate(test_de)
['I like to drink beer']
export_fname = 'translation_export'

learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')

inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_de)
['I like to drink beer']

Tests

The following tests aim to ensure, as much as possible, that the core training code works for the pretrained translation models listed below. They are excluded from the CI workflow because they take a long time to run and require downloading a large amount of data.

Note: Feel free to modify the code below to test whatever pretrained translation models you are working with. If any of your pretrained translation models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

try: del learn; torch.cuda.empty_cache()
except NameError: pass
[ model_type for model_type in BLURR.get_models(task='ConditionalGeneration') 
 if (not model_type.startswith('TF')) ]
['BartForConditionalGeneration',
 'BigBirdPegasusForConditionalGeneration',
 'BlenderbotForConditionalGeneration',
 'BlenderbotSmallForConditionalGeneration',
 'FSMTForConditionalGeneration',
 'LEDForConditionalGeneration',
 'M2M100ForConditionalGeneration',
 'MBartForConditionalGeneration',
 'MT5ForConditionalGeneration',
 'PegasusForConditionalGeneration',
 'ProphetNetForConditionalGeneration',
 'Speech2TextForConditionalGeneration',
 'T5ForConditionalGeneration',
 'XLMProphetNetForConditionalGeneration']
pretrained_model_names = [
    'facebook/bart-base',
    'facebook/wmt19-de-en',                      # FSMT
    'Helsinki-NLP/opus-mt-de-en',                # MarianMT
    #'sshleifer/tiny-mbart',
    #'google/mt5-small',
    't5-small'
]
path = Path('./')
ds = load_dataset('wmt16', 'de-en', split='train[:1%]')
wmt_df = pd.DataFrame(ds['translation'], columns=['de', 'en']); len(wmt_df)
wmt_df = wmt_df.iloc[:1000]
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a)
model_cls = AutoModelForSeq2SeqLM
bsz = 2
inp_seq_sz = 128; trg_seq_sz = 128

test_results = []
for model_name in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_tok_kwargs = {}
    if (model_name == 'sshleifer/tiny-mbart'):
        hf_tok_kwargs['src_lang'], hf_tok_kwargs['tgt_lang'] = "de_DE", "en_XX"
            
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(model_name, 
                                                                      model_cls=model_cls, 
                                                                      tokenizer_kwargs=hf_tok_kwargs)
    
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n')
    
    # 1. build your DataBlock
    text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='translation')
    
    def add_t5_prefix(inp): return f'translate German to English: {inp}' if (hf_arch == 't5') else inp
    
    before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                      padding='max_length', 
                                                      max_length=inp_seq_sz, 
                                                      max_target_length=trg_seq_sz, 
                                                      text_gen_kwargs=text_gen_kwargs)
    
    blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
    dblock = DataBlock(blocks=blocks,
                       get_x=Pipeline([ColReader('de'), add_t5_prefix]),
                       get_y=ColReader('en'),
                       splitter=RandomSplitter())

    dls = dblock.dataloaders(wmt_df, bs=bsz) 
    b = dls.one_batch()

    # 2. build your Learner
    seq2seq_metrics = {}
    
    model = HF_BaseModelWrapper(hf_model)
    fit_cbs = [
        ShortEpochCallback(0.05, short_valid=True), 
        HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)
    ]

    learn = Learner(dls, 
                    model,
                    opt_func=ranger,
                    loss_func=HF_PreCalculatedLoss(),
                    cbs=[HF_BaseModelCallback],
                    splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

    learn.create_opt() 
    learn.freeze()
    
    # 3. Run your tests
    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

#         print('*** TESTING One pass through the model ***')
#         preds = learn.model(b[0])
#         test_eq(preds[1].shape[0], bsz)
#         test_eq(preds[1].shape[2], hf_config.vocab_size)

        print('*** TESTING Training/Results ***')
        learn.fit_one_cycle(1, lr_max=1e-3, cbs=fit_cbs)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'PASSED', ''))
        learn.show_results(learner=learn, max_n=2, input_trunc_at=500, target_trunc_at=250)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'FAILED', err))
    finally:
        # cleanup
        del learn; torch.cuda.empty_cache()
   arch    tokenizer          model_name                    result  error
0  bart    BartTokenizerFast  BartForConditionalGeneration  PASSED
1  fsmt    FSMTTokenizer      FSMTForConditionalGeneration  PASSED
2  marian  MarianTokenizer    MarianMTModel                 PASSED
3  t5      T5TokenizerFast    T5ForConditionalGeneration    PASSED
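The test loop passes padding='max_length' with max_length=inp_seq_sz. This is why it can assert b[0]['input_ids'].shape == torch.Size([bsz, inp_seq_sz]) rather than depend on the longest sequence in each batch. A fixed shape also requires truncating longer inputs. The sketch below shows roughly how each sequence is handled under that assumption. pad_or_truncate is hypothetical, pad_id stands in for the tokenizer's real pad token id, and special tokens such as the end-of-sequence marker are ignored.

```python
def pad_or_truncate(ids, max_length, pad_id):
    """Force a token-id list to exactly `max_length` items (right side)."""
    ids = ids[:max_length]                            # truncate long sequences
    return ids + [pad_id] * (max_length - len(ids))   # pad short ones

batch = [list(range(200)), [7, 8, 9]]
fixed = [pad_or_truncate(seq, 128, pad_id=0) for seq in batch]
print([len(seq) for seq in fixed])  # [128, 128]
```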

Summary

This module includes the fundamental building blocks for using Blurr to train translation models and run inference with them.