This module contains custom models, custom splitters, etc. for translation tasks.
 
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti

Translation

Translation tasks attempt to convert text in one language into another.

Prepare the data

ds = load_dataset('wmt16', 'de-en', split='train[:1%]')
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/7b2c4443a7d34c2e13df267eaa8cab4c62dd82f6b62b0d9ecc2e3a673ce17308)
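Each wmt16 row stores its language pair under a single translation key, which is why ds['translation'] can feed pd.DataFrame directly below. A quick illustrative check (not part of the original run):

# Each row is of the form {'translation': {'de': ..., 'en': ...}} (illustrative check).
print(ds[0]['translation'].keys())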
path = Path('./')
wmt_df = pd.DataFrame(ds['translation'], columns=['de', 'en']); len(wmt_df)
45489
wmt_df = wmt_df.iloc[:1000]
wmt_df.head(2)
de en
0 Wiederaufnahme der Sitzungsperiode Resumption of the session
1 Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten. I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.
pretrained_model_name = "facebook/bart-large-cnn"
task = HF_TASKS_AUTO.Seq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, task=task)

hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('bart',
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
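As a quick sanity check (illustrative, not part of the original run), the returned tokenizer can be exercised on a short, made-up German sentence; decoding should round-trip the text:

# Hypothetical sentence: encode, then decode to confirm the tokenizer round-trips.
enc = hf_tokenizer('Guten Morgen, Europa!')
print(enc['input_ids'])
print(hf_tokenizer.decode(enc['input_ids'], skip_special_tokens=True))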
blocks = (HF_Seq2SeqBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader('de'), get_y=ColReader('en'), splitter=RandomSplitter())
dls = dblock.dataloaders(wmt_df, bs=2)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 325]), torch.Size([2, 86]))
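To confirm the inputs carry the German source and the targets the English reference, the first batch element can be decoded back to text (an illustrative check; any -100 padding in the targets is dropped first since it is not a real token id):

# Decode the first example of the batch back to text (illustrative check).
print(hf_tokenizer.decode(b[0]['input_ids'][0].tolist(), skip_special_tokens=True)[:100])
trg_ids = [t for t in b[1][0].tolist() if t != -100]  # drop ignored positions, if any
print(hf_tokenizer.decode(trg_ids, skip_special_tokens=True)[:100])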
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
text target
0 Angesichts dieser Situation muß aus dem Bericht, den das Parlament annimmt, klar hervorgehen, daß Maßnahmen notwendig sind, die eindeutig auf die Bekämpfung der relativen Armut und der Arbeitslosigkeit gerichtet sind. Maßnahmen wie die für diese Zwe Given this situation, the report approved by Parliament must highlight the need for measures that aim unequivocally to fight relative poverty and unemployment: measures such as the appropriate use of structural funds for these purposes, which are of
1 Es muß daran erinnert werden, daß die globale Wettbewerbsfähigkeit der Europäischen Union gegenwärtig 81 % des Niveaus der Vereinigten Staaten von Amerika erreicht und daß diese Kennziffer sich nur dann verbessern wird, wenn sich die unserer wettbew It must be remembered that, currently, the European Union' s overall competitiveness is, in general terms, 81% of that of the United States of America and that this figure will only improve if the figure for our competitive units, that is the region

Train model

seq2seq_metrics = {
    'bleu': { 'returns': "bleu" },
    'meteor': { 'returns': "meteor" },
    'sacrebleu': { 'returns': "score" }
}
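Each key above names a metric from the huggingface datasets library, and 'returns' selects which field of that metric's compute() output to report. As a standalone illustration (hypothetical strings, not from this run), sacrebleu's 'score' field can be computed like so:

# Minimal sketch of what the metrics callback computes per metric (hypothetical inputs).
from datasets import load_metric

sacrebleu = load_metric('sacrebleu')
preds = ['I like to drink beer']    # hypothetical prediction
refs = [['I like drinking beer']]   # one list of reference strings per prediction
print(sacrebleu.compute(predictions=preds, references=refs)['score'])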

model = HF_BaseModelWrapper(hf_model)

learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(dls, 
                model,
                opt_func=Adam,
                loss_func=CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) #.to_native_fp16() #.to_fp16()

learn.create_opt() 
learn.freeze()
 
# preds = learn.model(b[0])

# len(preds),preds['loss'].shape, preds['logits'].shape
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 325]), 2, torch.Size([2, 86]))
print(len(learn.opt.param_groups))
3
learn.lr_find(suggestions=True)
SuggestedLRs(lr_min=8.317637839354575e-05, lr_steep=0.00015848931798245758)
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch train_loss valid_loss bleu meteor sacrebleu time
0 2.287987 2.100064 0.088155 0.290024 8.252259 02:28
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=500)
text target prediction
0 Aus diesem Grund ist es eines der wichtigsten und weitreichendsten Ziele, die wir uns in der Europäischen Union stellen sollten, Anstrengungen zur Schaffung neuer Arbeitsplätze in den ländlichen Gebieten außerhalb des Agrarsektors zu unternehmen, unter anderem in den Bereichen ländlicher Tourismus, Sport, Kultur, Sanierung der Ressourcen, Umstellung von Unternehmen, neue Technologien, Dienstleistungen usw. Doch obwohl die Landwirtschaft keine ausschließliche Rolle mehr spielt, ist sie weiterhin For this reason, one of the most important and essential objectives which we should set in the European Union is to make efforts to create new jobs in rural areas, outside of the agricultural sector, in sectors such as rural tourism, sport, culture, heritage conservation, the conversion of businesses, new technologies, services, etc. However, even though the role of agriculture is not exclusive, it is still essential, not only to prevent economic and social disintegration and the creation of gh Despite the fact that the Landwirtschaft has no significant role to play, it is still very important, not only to ensure the economic and social cohesion of the regions, but also to ensure that the Territoriums have a significant role in the development of the territorial environment and the protection of the environment. It is one of the most important objectives of the European Union, which is to ensure an increase in employment in the rural areas, particularly in tourism, sport, culture, agr
1 Herr Präsident, Herr Kommissar, sehr geehrte Kollegen! Als Beweis dafür, daß dieses Parlament die Phase einer beratenden und untergeordneten Institution noch nicht überwunden hat, konnte der ausgezeichnete Bericht meiner Fraktionskollegin Elisabeth Schroedter nicht ins Plenum gelangen, weil die Regionalentwicklungspläne des Zeitraums 2000-2006 für die Ziel-1-Regionen bereits mehrere Monate in den Arbeitszimmern der Kommission liegen. Mr President, Commissioner, as proof that this Parliament has not yet overcome its role as a consultative and subordinate institution, the excellent report by a fellow member of my Group, Elisabeth Schroedter, has not been able to reach plenary sitting because the plans for regional development for the period 2000-2006 for Objective 1 regions have been sitting in the Commission' s offices for several months. Mr President, Commissioner, I would like to congratulate you on a very good report by your rapporteur, Elisabeth Schroedter, which did not make it to the plenary session because it had already been discussed for several months in the Commission' s offices.
test_de = "Ich trinke gerne Bier"
outputs = learn.blurr_generate(test_de, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 I am a beer drinker, and I like to drink a beer, as you know. It is a very good thing to be a beer tipple, in a sense. It's a very beautiful thing. I like it. I think it is a good thing.

=== Prediction 2 ===
 I am a beer drinker, and I like to drink a beer, as you know. It is a very good thing to be a beer tipple, in a sense. It's a very beautiful thing. I like it. I think it is a good idea.

=== Prediction 3 ===
 I am a beer drinker, and I like to drink a beer, as you know. It is a very good thing to be a beer tipple, in a sense. It's a very beautiful thing. I like it. I think it is a good product.
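The generations above read as loose paraphrases rather than tight translations, which is unsurprising after a single epoch on 1,000 pairs with a checkpoint pretrained for summarization. Assuming blurr_generate forwards extra keyword arguments to the underlying generate call (as the num_return_sequences usage above suggests), decoding can be constrained with standard generation parameters, e.g.:

# Hypothetical settings, assumed to be passed through to the model's generate().
outputs = learn.blurr_generate(test_de, num_beams=4, max_length=60, early_stopping=True)
print(outputs[0])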

Inference

export_fname = 'translation_export'
learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_de)
[" I am a beer drinker, and I like to drink a beer, as you know. It is a very good thing to be a beer tipple, in a sense. It's a very beautiful thing. I like it. I think it is a good thing."]

Tests

The purpose of the following tests is to ensure, as much as possible, that the core training code works for the pretrained translation models below. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data they would need to download.

Note: Feel free to modify the code below to test whatever pretrained translation models you are working with ... and if any of your pretrained translation models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

try: del learn; torch.cuda.empty_cache()
except: pass
[ model_type for model_type in BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration') 
 if (not model_type.__name__.startswith('TF')) ]
[transformers.models.bart.modeling_bart.BartForConditionalGeneration,
 transformers.models.blenderbot.modeling_blenderbot.BlenderbotForConditionalGeneration,
 transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallForConditionalGeneration,
 transformers.models.fsmt.modeling_fsmt.FSMTForConditionalGeneration,
 transformers.models.led.modeling_led.LEDForConditionalGeneration,
 transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration,
 transformers.models.mt5.modeling_mt5.MT5ForConditionalGeneration,
 transformers.models.pegasus.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration,
 transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
 transformers.models.xlm_prophetnet.modeling_xlm_prophetnet.XLMProphetNetForConditionalGeneration]
pretrained_model_names = [
    'facebook/bart-base',
    'facebook/wmt19-de-en',                      # FSMT
    'Helsinki-NLP/opus-mt-de-en',                # MarianMT
    #'sshleifer/tiny-mbart',
    #'google/mt5-small',
    't5-small'
]
path = Path('./')
ds = load_dataset('wmt16', 'de-en', split='train[:1%]')
wmt_df = pd.DataFrame(ds['translation'], columns=['de', 'en']); len(wmt_df)
wmt_df = wmt_df.iloc[:1000]
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/7b2c4443a7d34c2e13df267eaa8cab4c62dd82f6b62b0d9ecc2e3a673ce17308)
#hide_output
task = HF_TASKS_AUTO.Seq2SeqLM
bsz = 2
inp_seq_sz = 128; trg_seq_sz = 128

test_results = []
for model_name in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, task=task)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n')
    
    # 1. build your DataBlock
    text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='translation')
    
    tok_kwargs = {}
    if (hf_arch == 'mbart'):
        tok_kwargs['src_lang'], tok_kwargs['tgt_lang'] = "de_DE", "en_XX"
            
    def add_t5_prefix(inp): return f'translate German to English: {inp}' if (hf_arch == 't5') else inp
    
    before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                      padding='max_length', 
                                                      max_length=inp_seq_sz, 
                                                      max_target_length=trg_seq_sz, 
                                                      tok_kwargs=tok_kwargs, text_gen_kwargs=text_gen_kwargs)
    
    blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
    dblock = DataBlock(blocks=blocks, 
                   get_x=Pipeline([ColReader('de'), add_t5_prefix]), 
                   get_y=ColReader('en'), 
                   splitter=RandomSplitter())

    dls = dblock.dataloaders(wmt_df, bs=bsz) 
    b = dls.one_batch()

    # 2. build your Learner
    seq2seq_metrics = {}
    
    model = HF_BaseModelWrapper(hf_model)
    fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

    learn = Learner(dls, 
                    model,
                    opt_func=ranger,
                    loss_func=HF_PreCalculatedLoss(),
                    cbs=[HF_BaseModelCallback],
                    splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

    learn.create_opt() 
    learn.freeze()
    
    # 3. Run your tests
    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

#         print('*** TESTING One pass through the model ***')
#         preds = learn.model(b[0])
#         test_eq(preds[1].shape[0], bsz)
#         test_eq(preds[1].shape[2], hf_config.vocab_size)

        print('*** TESTING Training/Results ***')
        learn.fit_one_cycle(1, lr_max=1e-3, cbs=fit_cbs)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'PASSED', ''))
        learn.show_results(learner=learn, max_n=2, input_trunc_at=500, target_trunc_at=250)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'FAILED', err))
    finally:
        # cleanup
        del learn; torch.cuda.empty_cache()
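A results table like the one below can then be built from test_results (a minimal sketch, assuming the usual pandas approach):

arch_df = pd.DataFrame(test_results, columns=['arch', 'tokenizer', 'model_name', 'result', 'error'])
arch_df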
arch tokenizer model_name result error
0 bart BartTokenizerFast BartForConditionalGeneration PASSED
1 fsmt FSMTTokenizer FSMTForConditionalGeneration PASSED
2 marian MarianTokenizer MarianMTModel PASSED
3 t5 T5TokenizerFast T5ForConditionalGeneration PASSED

Cleanup