Fine-Tuning in Machine Learning

Introduction

Machine learning is one of the defining technologies of the modern world. New models are trained and deployed every day, and ensuring that they respond accurately on the task at hand is the developer's responsibility. Fine-tuning is a practical way to meet that responsibility without building a model from scratch.

Understanding Fine-tuning

Fine-tuning is a form of transfer learning in which a pre-trained model is reused as the starting point for a new task. Rather than learning everything anew, we make relatively small adjustments to the pre-trained model's parameters to tailor it to the specific task. Because fine-tuning builds on the knowledge the model has already acquired, it is far more efficient than training from scratch.

Working of Fine-tuning

  1. Selection of a Pre-Trained Model: The first step is to choose a pre-trained model that is most relevant to the target task. Common choices include GPT and T5 for NLP tasks and ResNet for computer vision.
  2. Replacing the Final Layer: The final layer of the pre-trained model, which was trained for a different objective, is replaced with a new layer suited to the new task (see the sketch after this list). In the rare case that the final layer already matches the new task, this step can be skipped.
  3. Adjusting the Model: The pre-trained model is then trained further on the new dataset until it adapts to the task.
  4. Freezing Early Layers: The earlier layers, which capture general features, are frozen so their weights are not updated during training.
  5. Training Later Layers: The later layers, which capture more task-specific features, are trained on the new dataset.
  6. Fine-tuning the Entire Model: In some cases, all layers of the model are fine-tuned for better adaptation to the new task.
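
Below is a minimal PyTorch sketch of steps 2, 4, and 5, assuming a torchvision ResNet-18 and a hypothetical 10-class image task (both are illustrative choices, not requirements):

import torch.nn as nn
from torchvision import models

# Step 1: select a pre-trained model (ResNet-18 trained on ImageNet)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Step 4: freeze the earlier layers so their weights are not updated
for param in model.parameters():
    param.requires_grad = False

# Step 2: replace the final layer with one sized for the new task
# (10 output classes is an assumption made for illustration)
model.fc = nn.Linear(model.fc.in_features, 10)

# Step 5: the new final layer is the only trainable part now; an
# optimizer built over model.parameters() will update just model.fc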

Implementation 

Fine-Tuning BERT for Text Classification

In this example, we'll fine-tune the BERT model from Hugging Face's Transformers library for a text classification task.

1. Install Required Libraries

pip install transformers torch datasets

2. Import Libraries

import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

3. Load a Pre-trained Model and Tokenizer

# Load pre-trained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

4. Prepare the Dataset

Let's use the datasets library to load the IMDB movie review dataset, a standard benchmark for binary sentiment classification.

# Load dataset
dataset = load_dataset('imdb')

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Format the dataset for PyTorch
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
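
The full IMDB training split contains 25,000 reviews, so a complete fine-tuning pass can take a long time on modest hardware. As an optional shortcut (not part of the original recipe), you can train on a small random subset first:

# Optional: sample smaller splits for a quick trial run
small_train = tokenized_datasets['train'].shuffle(seed=42).select(range(2000))
small_eval = tokenized_datasets['test'].shuffle(seed=42).select(range(500))

If you use these, pass small_train and small_eval to the Trainer below in place of the full splits.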

5. Define Training Arguments

training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)

6. Initialize the Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test']
)
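
By default the Trainer reports only the evaluation loss. To also report accuracy, one common pattern (sketched here with numpy; it is not part of the original snippet) is to pass a compute_metrics function when constructing the Trainer:

import numpy as np

def compute_metrics(eval_pred):
    # The Trainer supplies a (logits, labels) pair for the eval set
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {'accuracy': (predictions == labels).mean()}

# Add compute_metrics=compute_metrics to the Trainer(...) call above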

7. Train the Model

trainer.train()

8. Evaluate the Model

results = trainer.evaluate()
print(results)
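
Once training finishes, the fine-tuned model can classify new text directly. A minimal usage sketch (the review sentence is invented for illustration; IMDB uses label 1 for positive and 0 for negative):

text = "This movie was a wonderful surprise from start to finish."
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

model.eval()
with torch.no_grad():
    logits = model(**inputs).logits

predicted_class = logits.argmax(dim=-1).item()
print('positive' if predicted_class == 1 else 'negative')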

Explanation

  • Loading the Model and Tokenizer: We use BertTokenizer to tokenize text and BertForSequenceClassification for the classification task. The model is initialized with pre-trained weights and modified for binary classification.
  • Preparing the Dataset: We use the datasets library to load and tokenize the dataset. Tokenization converts text into input IDs and attention masks that the model can process.
  • Training Arguments: We specify parameters like learning rate, batch size, and number of epochs. These control the training process.
  • Training and Evaluation: The Trainer class simplifies the training and evaluation process. It handles the training loop and evaluation based on the specified arguments.
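
The walkthrough above updates all of BERT's weights. To mirror steps 4 and 5 from the working section instead (freeze the early layers, train only the later ones), one possible variation, not used in the run above, is to freeze the pre-trained encoder before training:

# Freeze the pre-trained BERT encoder; only the newly added
# classification head will receive gradient updates
for param in model.bert.parameters():
    param.requires_grad = False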

Conclusion

Fine-tuning is a powerful technique in machine learning because it reuses a pre-trained model and adapts it efficiently to a new task with good accuracy. It also greatly reduces the effort and compute required compared to training a new model from scratch.

