Transformers and its Trainer
The `Trainer` class provides an API for feature-complete training in PyTorch, and it supports distributed training on multiple GPUs/TPUs as well as mixed precision for NVIDIA GPUs, AMD GPUs, and `torch.amp` for PyTorch. `Trainer` goes hand-in-hand with the `TrainingArguments` class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.
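A minimal sketch of how the two classes fit together (the checkpoint name, output directory, and the dataset objects `train_ds`/`val_ds` are placeholders, not part of the API):

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# assumed: train_ds and val_ds are already-tokenized datasets
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

args = TrainingArguments(
    output_dir="my_run",
    num_train_epochs=3,
    per_device_train_batch_size=16,
)

trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=val_ds)
trainer.train()
metrics = trainer.evaluate()
```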
Two other classes, `Seq2SeqTrainer` and `Seq2SeqTrainingArguments`, inherit from the `Trainer` and `TrainingArguments` classes. They are adapted for sequence-to-sequence tasks such as summarization or translation.
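For such tasks the setup looks the same, only with the sequence-to-sequence variants swapped in (a sketch; the checkpoint and dataset objects are placeholders, and `predict_with_generate` is the `Seq2SeqTrainingArguments` option that makes evaluation use generation):

```python
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
args = Seq2SeqTrainingArguments(output_dir="summarization_run", predict_with_generate=True)

trainer = Seq2SeqTrainer(model=model, args=args, train_dataset=train_ds, eval_dataset=val_ds)
```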
Parameters
The `transformers.Trainer` class has the following arguments.
- `model`
  - it can be a pretrained model downloaded from Hugging Face's model hub or a `torch.nn.Module`.
  - if the model to train, evaluate, or use for predictions is not provided, a `model_init` callable must be passed instead (note that the function itself is passed, not its return value), like so:
    `trainer = transformers.Trainer(model_init=function_to_instantiate_a_model)`
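What that callable might look like (a sketch; the checkpoint name is a placeholder and `function_to_instantiate_a_model` is just the name used in the snippet above):

```python
from transformers import AutoModelForSequenceClassification

def function_to_instantiate_a_model():
    # returns a fresh model instance every time the Trainer asks for one
    return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
```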
- `args`
  - an instantiated `TrainingArguments` object is its input; see the sketch below.
  - if not given, the arguments will default to a basic instance of `TrainingArguments` with the output directory set to tmp_trainer in the current directory.
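A sketch of a typical `TrainingArguments` instance (the values shown are illustrative, not defaults):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="my_run",               # where checkpoints and logs are written
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=5e-5,
    weight_decay=0.01,
)
```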
- `data_collator`
  - the function used to form a batch from a list of elements of `train_dataset` or `eval_dataset`.
  - will default to `default_data_collator()` if no tokenizer is provided; otherwise `DataCollatorWithPadding` is the default (see the sketch below).
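For instance, dynamic padding with a tokenizer-aware collator might be set up like this (the checkpoint name is a placeholder, and `model`, `args`, and `train_ds` are assumed to exist already):

```python
from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)  # pads each batch to its longest sequence

trainer = Trainer(model=model, args=args, train_dataset=train_ds, data_collator=data_collator)
```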
- `train_dataset`
  - the dataset used for training.
  - if it is a Dataset (as from the Hugging Face datasets library), the columns not accepted by the `model.forward()` method are automatically removed.
  - if the dataset passed is a `torch.utils.data.IterableDataset` with some randomization and training is to be done in a distributed fashion, the iterable dataset should either use an internal attribute `generator` that is a `torch.Generator` for the randomization, identical on all processes (the Trainer will manually set the seed of this generator at each epoch), or have a `set_epoch()` method that internally sets the seed of the random number generators used; a sketch of such a dataset follows this list.
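A minimal sketch of an iterable dataset that satisfies both hooks (the class and its data are hypothetical; only the `generator` attribute and the `set_epoch()` method matter here):

```python
import torch

class ShuffledStream(torch.utils.data.IterableDataset):
    def __init__(self, samples):
        self.samples = list(samples)
        self.generator = torch.Generator()  # the Trainer can re-seed this at every epoch

    def set_epoch(self, epoch):
        # alternative hook: derive the shuffling seed from the epoch number
        self.generator.manual_seed(epoch)

    def __iter__(self):
        order = torch.randperm(len(self.samples), generator=self.generator)
        for i in order:
            yield self.samples[i]
```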
- `eval_dataset`
  - this can be a `torch.utils.data.Dataset` or a dictionary of datasets.
  - if it is a Dataset, columns not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each dataset, prepending the dictionary key to the metric name (see the sketch below).
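For example, evaluating on two splits at once (the split names and dataset objects are placeholders; `model`, `args`, and `compute_metrics` are assumed to be defined elsewhere):

```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset={"validation": val_ds, "ood": ood_ds},
    compute_metrics=compute_metrics,
)
# metrics are reported per split, e.g. eval_validation_accuracy and eval_ood_accuracy
```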
- `tokenizer` (look here for more details on tokenizers)
  - the tokenizer used to preprocess the data. If provided, it will be used to automatically pad the inputs to the maximum length when batching inputs, and it will be saved along with the model to make it easier to rerun an interrupted training or to reuse the fine-tuned model.
- `model_init`
  - a callable function that instantiates the model to be trained.
  - the function may take zero arguments, or a single argument containing the optuna/Ray Tune/SigOpt trial object, so that different architectures can be chosen according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.); see the sketch below.
  - if provided, each call to `train()` will start from a new instance of the model as given by this function.
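A sketch of a trial-aware `model_init`, e.g. for use with `Trainer.hyperparameter_search()` (the checkpoint, the dropout range, and the optuna-style `trial.suggest_float` call are illustrative assumptions):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

def model_init(trial=None):
    config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=2)
    if trial is not None:
        # let each hyperparameter-search trial pick its own dropout probability
        config.hidden_dropout_prob = trial.suggest_float("hidden_dropout_prob", 0.0, 0.3)
    return AutoModelForSequenceClassification.from_config(config)
```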
- `compute_metrics`
  - the function that will be used to compute metrics at evaluation; see the sketch below.
  - it must take an EvalPrediction** (the evaluation output, which always contains the labels) and return a dictionary mapping metric-name strings to metric values.
  - **EvalPrediction is a utility class of transformers' trainer that holds the model's predictions and the targets they are to be matched against.
  - Note: when passing TrainingArguments with `batch_eval_metrics` set to `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the function needs to calculate and return the global summary statistics rather than accumulating the batch-level statistics.
- `callbacks`
  - a list of **callbacks to customize the training loop. These will be added to the list of default callbacks; see the sketch below.
  - if you want to remove one of the default callbacks used, use the `Trainer.remove_callback()` method.
  - **callbacks are objects that inspect the state of the training loop at certain events and can take some decisions based on it. (refer)
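A sketch of a custom callback (the logging behaviour is only an example; `model`, `args`, and `train_ds` are assumed to exist already):

```python
from transformers import Trainer, TrainerCallback

class EpochLogger(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        # state tracks training progress; print a short status line after each epoch
        print(f"finished epoch {state.epoch:.0f} at global step {state.global_step}")

trainer = Trainer(model=model, args=args, train_dataset=train_ds, callbacks=[EpochLogger()])
```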
- `optimizers`
  - a tuple containing the optimizer and the learning-rate scheduler to use; see the sketch below.
  - will default to an instance of AdamW on your model and a scheduler given by `get_linear_schedule_with_warmup()` controlled by `args`.
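Passing a custom pair explicitly might look like this (the learning rate, warmup steps, and total step count are placeholders):

```python
import torch
from transformers import Trainer, get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
num_training_steps = 1_000  # normally derived from the dataloader length and the epoch count
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=num_training_steps
)

trainer = Trainer(model=model, args=args, train_dataset=train_ds,
                  optimizers=(optimizer, scheduler))
```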
- `preprocess_logits_for_metrics`
  - a function that preprocesses the logits right before caching them at each evaluation step; see the sketch below.
  - it must take two tensors, the logits and the labels, and return the logits once processed as desired.
  - the modifications made by this function will be reflected in the predictions received by compute_metrics.
  - note that the labels (second parameter) will be `None` if the dataset does not have them.
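A common sketch is to reduce the logits to class ids so the cached evaluation tensors stay small (reducing via argmax is an illustrative choice, not the only one):

```python
import torch

def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # some models return extra outputs alongside the logits
        logits = logits[0]
    return torch.argmax(logits, dim=-1)  # compute_metrics then receives class ids instead of raw logits
```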
`Trainer` is a simple but feature-complete training and evaluation loop for PyTorch, optimized for Hugging Face Transformers.