Transformers and its Trainer

deepblue research
Jun 3, 2024

The Trainer class provides an API for feature-complete training in PyTorch, and it supports distributed training on multiple GPUs/TPUs, mixed precision for NVIDIA GPUs, AMD GPUs, and torch.amp for PyTorch. Trainer goes hand-in-hand with the TrainingArguments class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.

Two other classes, Seq2SeqTrainer and Seq2SeqTrainingArguments, inherit from the Trainer and TrainingArguments classes. They are adapted for sequence-to-sequence tasks such as summarization or translation.

Parameters

The transformers.Trainer class accepts the following arguments.

  • model
    - it can be a pretrained model downloaded from the Hugging Face Hub or any torch.nn.Module.
    - if the model to train, evaluate or use for predictions is not provided, a model_init function must be passed instead. Pass the function itself, without calling it, like so,
trainer = transformers.Trainer(model_init=function_to_instantiate_a_model)
  • args
    - an instance of the TrainingArguments class.
    - if not given, it defaults to a basic instance of TrainingArguments with output_dir set to a directory named tmp_trainer in the current directory.
  • data_collator
    - the function to use to form a batch from a list of elements of train_dataset or eval_dataset.
    - will default to default_data_collator() if no tokenizer is provided.
    - otherwise DataCollatorWithPadding is the default.
  • train_dataset
    - the dataset used for training.
    - if it is a Dataset (as from the Hugging Face datasets library), the columns not accepted by the model.forward() method are automatically removed.
    - if you pass a torch.utils.data.IterableDataset that uses some randomization and training is distributed, the iterable dataset should do one of two things: either use an internal attribute named generator, a torch.Generator whose state must be identical on all processes (the Trainer will manually set the seed of this generator at each epoch), or have a set_epoch() method that internally sets the seed of the random number generators used.
  • eval_dataset
    - this can be a torch.utils.data.Dataset or a dictionary of datasets.
    - if it is a Dataset, columns not accepted by the model.forward() method are automatically removed. If it is a dictionary, the Trainer evaluates on each dataset and prepends the dictionary key to the metric name.
  • tokenizer
    - the tokenizer used to preprocess the data. If provided, it is used to automatically pad the inputs to the maximum length when batching them, and it is saved along with the model to make it easier to rerun an interrupted training or reuse the fine-tuned model.
  • model_init
    - a callable function that instantiates the model to be trained.
    - the function may take zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object, so that it can choose different architectures according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.).
    - If provided, each call to train() will start from a new instance of the model as given by this function.
  • compute_metrics
    - the function that will be used to compute metrics at evaluation.
    - it must take an EvalPrediction** and return a dictionary mapping metric names (strings) to metric values.
    - **EvalPrediction is a utility class for the Trainer that holds the evaluation output used to compute metrics: the model's predictions and the target labels they are matched against. It always contains labels.
    - Note: when passing TrainingArguments with batch_eval_metrics set to True, your compute_metrics function must also take a boolean compute_result argument. It is set to True after the last eval batch to signal that the function needs to calculate and return the global summary statistics rather than accumulate the batch-level statistics.
  • callbacks
    - a list of **callbacks to customize the training loop. These are added to the list of default callbacks.
    - if you want to remove one of the default callbacks, use the Trainer.remove_callback() method.
    - **callbacks are objects that inspect the state of the training loop at certain events and take decisions based on it.
  • optimizers
    - a tuple containing the optimizer and the scheduler to use.
    - defaults to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup(), controlled by args.
  • preprocess_logits_for_metrics
    - a function that preprocesses the logits right before caching them at each evaluation step.
    - it must take two tensors, the logits and the labels, and return the logits once processed as desired.
    - the modifications made by this function will be reflected in the predictions received by compute_metrics.
    - note that the labels (second parameter) will be None if the dataset does not have them.
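
The compute_metrics and preprocess_logits_for_metrics arguments above are easiest to see side by side. What follows is a minimal sketch, not the library's own code, and it assumes a single-label classification model. The first function reduces the logits to class ids so that less data is cached, and the second turns those ids into an accuracy value. The choice of accuracy as the metric and the tuple check are illustrative assumptions.

```python
def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw logits to predicted class ids before the Trainer caches them."""
    # Assumption: some models return a tuple whose first element holds the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Keep only the highest-scoring class per example.
    # `labels` may be None and is not needed for this reduction.
    return logits.argmax(dim=-1)


def compute_metrics(eval_pred):
    """Map an EvalPrediction-like object to a {metric name: value} dictionary."""
    # Because of the hook above, `predictions` already holds class ids.
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    # Element-wise comparison, then the fraction of matches.
    accuracy = (predictions == labels).mean()
    return {"accuracy": float(accuracy)}
```

Both functions would then be handed to the Trainer by keyword, as compute_metrics=compute_metrics and preprocess_logits_for_metrics=preprocess_logits_for_metrics.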

Trainer is a simple but feature-complete training and evaluation loop for PyTorch, optimized for Hugging Face Transformers.
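
To illustrate the model_init argument described in the parameter list, here is a rough sketch of a one-argument model_init that reads a dropout probability from a hyperparameter-search trial. Several things in it are assumptions rather than anything the Trainer requires: the bert-base-uncased checkpoint, the two-label setup, the hidden_dropout_prob config field (specific to BERT-style configs), the search range, and the helper names choose_dropout and build_trainer. The trial is assumed to be an Optuna trial; other backends pass a different kind of object.

```python
def choose_dropout(trial=None, default=0.1):
    """Pick a dropout probability, from an Optuna-style trial if one is given."""
    if trial is None:
        # No search is running, so fall back to a fixed value.
        return default
    # `suggest_float` is Optuna's trial API; the name and range are illustrative.
    return trial.suggest_float("dropout", 0.0, 0.3)


def model_init(trial=None):
    """Build a fresh model; it receives the trial object, or None outside a search."""
    # Imported lazily so choose_dropout stays usable without transformers installed.
    from transformers import AutoConfig, AutoModelForSequenceClassification

    checkpoint = "bert-base-uncased"  # assumed checkpoint, swap in your own
    config = AutoConfig.from_pretrained(
        checkpoint,
        num_labels=2,
        hidden_dropout_prob=choose_dropout(trial),
    )
    return AutoModelForSequenceClassification.from_pretrained(checkpoint, config=config)


def build_trainer(train_dataset, eval_dataset):
    """Wire model_init into a Trainer (hypothetical helper for this sketch)."""
    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="tmp_trainer")
    return Trainer(
        model_init=model_init,  # the function itself, not model_init()
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
```

Because model_init is passed as a callable, each call to train() starts from a new model instance, which is what lets a hyperparameter search try a different dropout value on every trial.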
