OpenAI Whisper Fine-tuning

With Hugging Face Transformers and transfer learning

Xin Cheng
4 min read · Dec 12, 2023

We walked through speech recognition inference with the OpenAI Whisper model in previous articles 1, 2. How can you use your own dataset to fine-tune the pretrained model?

TL;DR: The Hugging Face Transformers framework supports fine-tuning the OpenAI Whisper model.

Data

Dataset

There are many open-sourced speech recognition datasets covering multiple languages, ranging from large multi-thousand-hour corpora to smaller datasets that cover many more languages and dialects, such as:

  • VoxPopuli: 1,800 hours, 16 languages
  • Fleurs: 12 hours per language, 102 languages
  • There are also individual datasets hosted by OpenSLR

Components

The ASR pipeline can be decomposed into three components:

  1. A feature extractor which pre-processes the raw audio-inputs
  2. The model which performs the sequence-to-sequence mapping
  3. A tokenizer which post-processes the model outputs to text format

In Transformers, the Whisper model has an associated feature extractor and tokenizer, called WhisperFeatureExtractor and WhisperTokenizer respectively.
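
To see how the three pieces fit together at inference time, here is a minimal sketch (not part of the fine-tuning walkthrough; the silent audio array is just a placeholder for a real 16 kHz recording):

import numpy as np
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

audio_array = np.zeros(16000, dtype=np.float32)  # 1 s of silence as placeholder audio

# 1. feature extractor: raw audio -> padded log-Mel spectrogram
inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt")
# 2. model: sequence-to-sequence mapping from spectrogram to token ids
predicted_ids = model.generate(inputs.input_features)
# 3. tokenizer: token ids -> text
print(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))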

Feature extractor

It is crucial that we match the sampling rate of our audio inputs to the sampling rate expected by our model, as audio signals with different sampling rates have very different distributions.

The Whisper feature extractor performs two operations: it first pads/truncates a batch of audio samples so that every sample has an input length of 30 s, and it then converts the padded audio arrays to log-Mel spectrograms.
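
A quick check of both operations (a minimal sketch; the shape in the comment assumes whisper-small's defaults of 80 Mel bins and 30 s windows):

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

five_seconds = np.zeros(5 * 16000, dtype=np.float32)  # dummy 5 s clip at 16 kHz
features = feature_extractor(five_seconds, sampling_rate=16000, return_tensors="np")

# padded to 30 s, then converted to a log-Mel spectrogram
print(features.input_features.shape)  # expect (1, 80, 3000): 80 Mel bins x 3000 frames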

Training

Use the Hugging Face Transformers Seq2SeqTrainingArguments and Seq2SeqTrainer APIs.

Tokenizer

The tokenizer encodes transcription text to label ids (input_ids) and post-processes model output ids back to text.
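
For example, encoding a sentence and decoding the resulting label ids should reproduce the original text (a minimal sketch with a made-up transcription; the same tokenizer is loaded again in the walkthrough below):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

sentence = "a placeholder transcription"  # hypothetical example text
labels = tokenizer(sentence).input_ids                        # text -> label ids
decoded = tokenizer.decode(labels, skip_special_tokens=True)  # label ids -> text
print(decoded == sentence)  # should print True for a lossless round trip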

Steps

We are using Jupyter notebooks (Google Colab):

Log in to Hugging Face

# in the CLI: huggingface-cli login, then paste an access token from https://huggingface.co/login?next=%2Fsettings%2Ftokens
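
In a notebook, you can authenticate interactively instead (a minimal sketch using the standard huggingface_hub helper):

from huggingface_hub import notebook_login

notebook_login()  # paste the same access token when prompted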

Install dependencies

!pip install --upgrade pip
!pip install --upgrade datasets transformers accelerate soundfile librosa evaluate jiwer tensorboard gradio

Load dataset

from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True)

print(common_voice)

Remove additional metadata columns and resample the audio to the 16 kHz sampling rate expected by the Whisper model

# remove additional metadata information
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
# after the cast below, the audio column is decoded on the fly as a 16 kHz array
from datasets import Audio

common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
print(common_voice["train"][0])

Load the feature extractor and tokenizer

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")


from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

Prepare the dataset

  1. We load and resample the audio data by calling batch["audio"]. As explained above, Datasets performs any necessary resampling operations on the fly.
  2. We use the feature extractor to compute the log-Mel spectrogram input features from our 1-dimensional audio array.
  3. We encode the transcriptions to label ids through the use of the tokenizer.

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16 kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)
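
After the map call, only the model-facing columns should remain (a quick sanity check, assuming the prepare_dataset function above):

print(common_voice["train"].column_names)  # expect ['input_features', 'labels']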

Define a processor that wraps the feature extractor and tokenizer

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

Define data collator

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

# puts together a list of samples into a mini training batch, https://www.youtube.com/watch?v=-RPeakdlHYo
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
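
As a quick smoke test, collating two processed examples should give padded tensors (the shapes in the comment assume whisper-small's defaults and the prepare_dataset output above):

batch = data_collator([common_voice["train"][0], common_voice["train"][1]])
print({k: v.shape for k, v in batch.items()})
# e.g. input_features: torch.Size([2, 80, 3000]), labels: torch.Size([2, <max label length in the pair>])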

Define WER evaluation

import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Define training

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# no tokens are forced as decoder outputs and no tokens are suppressed during generation
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-hi",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
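
The Trainer tracks the model but not the processor, so you may want to save the feature extractor and tokenizer alongside the checkpoints before training (an optional step; save_pretrained is the standard Transformers API):

processor.save_pretrained(training_args.output_dir)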

Start training

trainer.train()
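
Because push_to_hub=True is set in the training arguments, checkpoints are uploaded to the Hub during training; once training finishes, the final model can be pushed as well (a minimal sketch using the standard Trainer API):

trainer.push_to_hub()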

