Working With Wav2vec2 Part 1: Finetuning XLS-R for Automatic Speech Recognition by @pictureinthenoise

by Picture in the Noise, May 4th, 2024

Too Long; Didn't Read

This guide explains the steps to finetune Meta AI's wav2vec2 XLS-R model for automatic speech recognition ("ASR"). The guide includes step-by-step instructions on how to build a Kaggle Notebook that can be used to finetune the model. The model is trained on a Chilean Spanish dataset.

Introduction

Meta AI introduced wav2vec2 XLS-R ("XLS-R") at the end of 2021. XLS-R is a machine learning ("ML") model for cross-lingual speech representation learning, and it was trained on over 400,000 hours of publicly available speech audio across 128 languages. Upon its release, the model represented a leap over Meta AI's XLSR-53 cross-lingual model, which was trained on approximately 50,000 hours of speech audio across 53 languages.


This guide explains the steps to finetune XLS-R for automatic speech recognition ("ASR") using a Kaggle Notebook. The model will be finetuned on Chilean Spanish, but the same general steps can be followed to finetune XLS-R on any other language you choose.


Running inference on the finetuned model will be described in a companion tutorial, making this guide the first of two parts. I decided to create a separate inference-specific guide as this finetuning guide became a bit long.


It is assumed you have an existing ML background and that you understand basic ASR concepts. Beginners may have a difficult time following/understanding the build steps.

A Bit of Background on XLS-R

The original wav2vec2 model introduced in 2020 was pretrained on 960 hours of Librispeech dataset speech audio and ~53,200 hours of LibriVox dataset speech audio. Upon its release, two model sizes were available: the BASE model with 95 million parameters and the LARGE model with 317 million parameters.


XLS-R, on the other hand, was pretrained on multilingual speech audio from 5 datasets:


  • VoxPopuli: A total of ~372,000 hours of parliamentary speech audio from the European Parliament across 23 European languages.
  • Multilingual Librispeech: A total of ~50,000 hours of speech audio across eight European languages, with the majority (~44,000 hours) of audio data in English.
  • CommonVoice: A total of ~7,000 hours of speech audio across 60 languages.
  • VoxLingua107: A total of ~6,600 hours of speech audio across 107 languages based on YouTube content.
  • BABEL: A total of ~1,100 hours of speech audio across 17 African and Asian languages based on conversational telephone speech.


There are 3 XLS-R models: XLS-R (0.3B) with 300 million parameters, XLS-R (1B) with 1 billion parameters, and XLS-R (2B) with 2 billion parameters. This guide will use the XLS-R (0.3B) model.

Approach

There are some great write-ups on how to finetune wav2vec2 models, with perhaps this one being a "gold standard" of sorts. Of course, the general approach here mimics what you will find in other guides. You will:


  • Load a training dataset of audio data and associated text transcriptions.
  • Create a vocabulary from the text transcriptions in the dataset.
  • Initialize a wav2vec2 processor that will extract features from the input data, as well as convert text transcriptions to sequences of labels.
  • Finetune wav2vec2 XLS-R on the processed input data.


However, there are three key differences between this guide and others:


  1. The guide does not provide as much "inline" discussion on relevant ML and ASR concepts.
    • While each sub-section on individual notebook cells will include details on the use/purpose of the particular cell, it is assumed you have an existing ML background and that you understand basic ASR concepts.
  2. The Kaggle Notebook that you will build organizes utility methods in top-level cells.
    • Whereas many finetuning notebooks tend to have a sort of "stream-of-consciousness"-type layout, I elected to organize all utility methods together. If you're new to wav2vec2, you may find this approach confusing. However, to reiterate, I do my best to be explicit when explaining the purpose of each cell in each cell's dedicated sub-section. If you are just learning about wav2vec2, you might benefit from taking a quick glance at my HackerNoon article wav2vec2 for Automatic Speech Recognition in Plain English.
  3. This guide describes the steps for finetuning only.
    • As mentioned in the Introduction, I opted to create a separate companion guide on how to run inference on the finetuned XLS-R model that you will generate. This was done to prevent this guide from becoming excessively long.

Prerequisites and Before You Get Started

To complete the guide, you will need to have:


  • An existing Kaggle account. If you don't have an existing Kaggle account, you need to create one.
  • An existing Weights and Biases account ("WandB"). If you don't have an existing Weights and Biases account, you need to create one.
  • A WandB API key. If you don't have a WandB API key, follow the steps here.
  • Intermediate knowledge of Python.
  • Intermediate knowledge of working with Kaggle Notebooks.
  • Intermediate knowledge of ML concepts.
  • Basic knowledge of ASR concepts.


Before you get started with building the notebook, it may be helpful to review the two sub-sections directly below. They describe:


  1. The training dataset.
  2. The Word Error Rate ("WER") metric used during training.

Training Dataset

As mentioned in the Introduction, the XLS-R model will be finetuned on Chilean Spanish. The specific dataset is the Chilean Spanish Speech Data Set developed by Guevara-Rukoz et al. It is available for download on OpenSLR. The dataset consists of two sub-datasets: (1) 2,636 audio recordings of Chilean male speakers and (2) 1,738 audio recordings of Chilean female speakers.


Each sub-dataset includes a line_index.tsv index file. Each line of each index file contains a pair of an audio filename and a transcription of the audio in the associated file, e.g.:


clm_08421_01719502739	Es un viaje de negocios solamente voy por una noche
clm_02436_02011517900	Se usa para incitar a alguien a sacar el mayor provecho del dia presente


I have uploaded the Chilean Spanish Speech Data Set to Kaggle for convenience. There is one Kaggle dataset for the recordings of Chilean male speakers and one Kaggle dataset for the recordings of Chilean female speakers. These Kaggle datasets will be added to the Kaggle Notebook that you will build following the steps in this guide.

Word Error Rate (WER)

WER is one metric that can be used to measure performance of automatic speech recognition models. WER provides a mechanism to measure how close a text prediction is to a text reference. WER accomplishes this by recording errors of 3 types:


  • substitutions (S): A substitution error is recorded when the prediction contains a word that is different from the analogous word in the reference. For example, this occurs when the prediction misspells a word in the reference.

  • deletions (D): A deletion error is recorded when the prediction does not contain a word that is present in the reference.

  • insertions (I): An insertion error is recorded when the prediction contains a word that is not present in the reference.


Obviously, WER works at the word-level. The formula for the WER metric is as follows:


WER = (S + D + I)/N

where:
S = number of substitution errors
D = number of deletion errors
I = number of insertion errors
N = number of words in the reference


A simple WER example in Spanish is as follows:


prediction: "Él está saliendo."
reference: "Él está saltando."


A table helps to visualize the errors in the prediction:

TEXT         WORD 1    WORD 2    WORD 3
prediction   Él        está      saliendo
reference    Él        está      saltando
             correct   correct   substitution

The prediction contains 1 substitution error, 0 deletion errors, and 0 insertion errors. So, the WER for this example is:


WER = (1 + 0 + 0) / 3 = 1/3 ≈ 0.33


It should be obvious that the Word Error Rate does not necessarily tell us what specific errors exist. In the example above, WER identifies that WORD 3 contains an error in the predicted text, but it doesn't tell us that the characters i and e are wrong in the prediction. Other metrics, such as the Character Error Rate ("CER"), can be used for more precise error analysis.
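To make the formula concrete, here is a minimal sketch of a WER calculation in plain Python. It relies on the fact that S + D + I is the word-level edit distance between the reference and the prediction. The function name word_error_rate is my own and is not part of the notebook, which uses the HuggingFace WER metric instead.

```python
def word_error_rate(reference: str, prediction: str) -> float:
    """Word-level edit distance (S + D + I) divided by the reference length N."""
    ref_words = reference.split()
    pred_words = prediction.split()
    if not ref_words:
        raise ValueError("The reference must contain at least one word.")

    # previous[j] is the edit distance between the reference words
    # processed so far and the first j predicted words.
    previous = list(range(len(pred_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current = [i]
        for j, pred_word in enumerate(pred_words, start=1):
            cost = 0 if ref_word == pred_word else 1
            current.append(min(
                previous[j] + 1,         # skip a reference word
                current[j - 1] + 1,      # skip a predicted word
                previous[j - 1] + cost,  # match or substitute
            ))
        previous = current
    return previous[-1] / len(ref_words)


# The example from this section: one substitution across three reference words.
print(round(word_error_rate("Él está saltando.", "Él está saliendo."), 2))  # 0.33
```

The sketch compares words exactly as written, so punctuation and letter case count as differences. That is one reason the transcriptions are cleaned before training.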

Building the Finetuning Notebook

You are now ready to start building the finetuning notebook.


  • Step 1 and Step 2 guide you through setting up your Kaggle Notebook environment.
  • Step 3 guides you through building the notebook itself. It contains 32 sub-steps representing the 32 cells of the finetuning notebook.
  • Step 4 guides you through running the notebook, monitoring training, and saving the model.

Step 1 - Fetch Your WandB API Key

Your Kaggle Notebook must be configured to send training run data to WandB using your WandB API key. In order to do that, you need to copy it.


  1. Log in to WandB at www.wandb.com.
  2. Navigate to www.wandb.ai/authorize.
  3. Copy your API key for use in the next step.

Step 2 - Setting Up Your Kaggle Environment

Step 2.1 - Creating a New Kaggle Notebook


  1. Log in to Kaggle.
  2. Create a new Kaggle Notebook.
  3. Of course, the name of the notebook can be changed as desired. This guide uses the notebook name xls-r-300m-chilean-spanish-asr.

Step 2.2 - Setting Your WandB API Key

A Kaggle Secret will be used to securely store your WandB API key.


  1. Click Add-ons on the Kaggle Notebook main menu.
  2. Select Secret from the pop-up menu.
  3. Enter the label WANDB_API_KEY in the Label field and enter your WandB API key for the value.
  4. Ensure that the Attached checkbox to the left of the WANDB_API_KEY label field is checked.
  5. Click Done.

Step 2.3 - Adding the Training Datasets

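The Chilean Spanish Speech Data Set has been uploaded to Kaggle as 2 distinct datasets, one with the recordings of Chilean male speakers and one with the recordings of Chilean female speakers.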


Add both of these datasets to your Kaggle Notebook.

Step 3 - Building the Finetuning Notebook

The following 32 sub-steps build each of the finetuning notebook's 32 cells in order.

Step 3.1 - CELL 1: Installing Packages

The first cell of the finetuning notebook installs dependencies. Set the first cell to:


### CELL 1: Install Packages ###
!pip install --upgrade torchaudio
!pip install jiwer


  • The first line upgrades the torchaudio package to the latest version. torchaudio will be used to load audio files and resample audio data.
  • The second line installs the jiwer package which is required to use the HuggingFace Datasets library load_metric method used later.

Step 3.2 - CELL 2: Importing Python Packages

The second cell imports required Python packages. Set the second cell to:


### CELL 2: Import Python packages ###
import wandb
from kaggle_secrets import UserSecretsClient
import math
import re
import numpy as np
import pandas as pd
import torch
import torchaudio
import json
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass
from datasets import Dataset, load_metric, load_dataset, Audio
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import Trainer


  • You are probably already familiar with most of these packages. Their use in the notebook will be explained as subsequent cells are built.
  • It is worth mentioning that the HuggingFace transformers library and associated Wav2Vec2* classes provide the backbone of the functionality used for finetuning.

Step 3.3 - CELL 3: Loading WER Metric

The third cell imports the HuggingFace WER evaluation metric. Set the third cell to:


### CELL 3: Load WER metric ###
wer_metric = load_metric("wer")


  • As mentioned earlier, WER will be used to measure the performance of the model on evaluation/holdout data.

Step 3.4 - CELL 4: Logging into WandB

The fourth cell retrieves your WANDB_API_KEY secret that was set in Step 2.2. Set the fourth cell to:


### CELL 4: Login to WandB ###
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key = wandb_api_key)


  • The API key is used to configure the Kaggle Notebook so that training run data is sent to WandB.

Step 3.5 - CELL 5: Setting Constants

The fifth cell sets constants that will be used throughout the notebook. Set the fifth cell to:


### CELL 5: Constants ###
# Training data
TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/"
TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/"
EXT = ".wav"
NUM_LOAD_FROM_EACH_SET = 1600

# Vocabulary
VOCAB_FILE_PATH = "/kaggle/working/"
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"

# Sampling rates
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000

# Training/validation data split
SPLIT_PCT = 0.10

# Model parameters
MODEL = "facebook/wav2vec2-xls-r-300m"
USE_SAFETENSORS = False

# Training arguments
OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr"
TRAIN_BATCH_SIZE = 18
EVAL_BATCH_SIZE = 10
TRAIN_EPOCHS = 30
SAVE_STEPS = 3200
EVAL_STEPS = 100
LOGGING_STEPS = 100
LEARNING_RATE = 1e-4
WARMUP_STEPS = 800


  • The notebook doesn't surface every conceivable constant in this cell. Some values that could be represented by constants have been left inline.
  • The use of many of the constants above should be self-evident. For those that are not, their use will be explained in the following sub-steps.

Step 3.6 - CELL 6: Utility Methods for Reading Index Files, Cleaning Text, and Creating Vocabulary

The sixth cell defines utility methods for reading the dataset index files (see the Training Dataset sub-section above), as well as for cleaning transcription text and creating the vocabulary. Set the sixth cell to:


### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ###
def read_index_file_data(path: str, filename: str):
    data = []
    with open(path + filename, "r", encoding = "utf8") as f:
        lines = f.readlines()
        for line in lines:
            file_and_text = line.split("\t")
            data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")])
    return data

def truncate_training_dataset(dataset: list) -> list:
    # Keep the full list when NUM_LOAD_FROM_EACH_SET is the string "all"
    if isinstance(NUM_LOAD_FROM_EACH_SET, str) and NUM_LOAD_FROM_EACH_SET.lower() == "all":
        return dataset
    else:
        return dataset[:NUM_LOAD_FROM_EACH_SET]
    
def clean_text(text: str) -> str:
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    cleaned_text = cleaned_text.lower()
    return cleaned_text

def create_vocab(data):
    vocab_list = []
    for index in range(len(data)):
        text = data[index][1]
        words = text.split(" ")
        for word in words:
            chars = list(word)
            for char in chars:
                if char not in vocab_list:
                    vocab_list.append(char)
    return vocab_list        


  • The read_index_file_data method reads a line_index.tsv dataset index file and produces a list of lists with audio filename and transcription data, e.g.:


[
    ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739.wav", "Es un viaje de negocios solamente voy por una noche"],
    ...
]


  • The truncate_training_dataset method truncates a list of index file data using the NUM_LOAD_FROM_EACH_SET constant set in Step 3.5. Specifically, the NUM_LOAD_FROM_EACH_SET constant is used to specify the number of audio samples that should be loaded from each dataset. For the purposes of this guide, the number is set at 1600 which means a total of 3200 audio samples will eventually be loaded. To load all samples, set NUM_LOAD_FROM_EACH_SET to the string value all.
  • The clean_text method is used to strip each text transcription of the characters specified by the regular expression assigned to SPECIAL_CHARS in Step 3.5. These characters, inclusive of punctuation, can be eliminated as they don't provide any semantic value when training the model to learn mappings between audio features and text transcriptions.
  • The create_vocab method creates a vocabulary from clean text transcriptions. Simply, it extracts all unique characters from the set of cleaned text transcriptions. You will see an example of the generated vocabulary in Step 3.14.
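As a quick illustration of the last two methods, the sketch below runs a condensed version of clean_text and create_vocab on two made-up samples. DEMO_SPECIAL_CHARS is a shortened, hypothetical pattern used only to keep the example small, and the filenames are placeholders. The notebook itself uses SPECIAL_CHARS from Step 3.5.

```python
import re

# Hypothetical, shortened pattern for illustration only.
DEMO_SPECIAL_CHARS = r"[\d,\-;!¡?¿\"':/.]"


def clean_text(text: str) -> str:
    # Strip the listed characters, then lowercase what remains.
    return re.sub(DEMO_SPECIAL_CHARS, "", text).lower()


def create_vocab(data):
    # Collect unique characters (spaces excluded) in order of first appearance.
    vocab_list = []
    for _, text in data:
        for char in text.replace(" ", ""):
            if char not in vocab_list:
                vocab_list.append(char)
    return vocab_list


samples = [
    ["clip_a.wav", "¿No te entiendo nada?"],
    ["clip_b.wav", "Es un viaje de negocios."],
]
cleaned = [[name, clean_text(text)] for name, text in samples]

print(cleaned[0][1])
# no te entiendo nada
print(create_vocab(cleaned))
# ['n', 'o', 't', 'e', 'i', 'd', 'a', 's', 'u', 'v', 'j', 'g', 'c']
```

Note that the space character never enters the vocabulary here. Word boundaries are handled separately by the word delimiter added in Step 3.15.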

Step 3.7 - CELL 7: Utility Methods for Loading and Resampling Audio Data

The seventh cell defines utility methods using torchaudio to load and resample audio data. Set the seventh cell to:


### CELL 7: Utility methods for loading and resampling audio data ###
def read_audio_data(file):
    speech_array, sampling_rate = torchaudio.load(file, normalize = True)
    return speech_array, sampling_rate

def resample(waveform):
    transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    return waveform[0]


  • The read_audio_data method loads a specified audio file and returns a torch.Tensor multi-dimensional matrix of the audio data along with the sampling rate of the audio. All the audio files in the training data have a sampling rate of 48000 Hz. This "original" sampling rate is captured by the constant ORIG_SAMPLING_RATE in Step 3.5.
  • The resample method is used to downsample audio data from a sampling rate of 48000 to 16000. wav2vec2 is pretrained on audio sampled at 16000 Hz. Accordingly, any audio used for finetuning must have the same sampling rate. In this case, the audio examples must be downsampled from 48000 Hz to 16000 Hz. 16000 Hz is captured by the constant TGT_SAMPLING_RATE in Step 3.5.
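If it helps to see the effect of downsampling in numbers, the following arithmetic sketch shows how many samples represent the same clip at each rate. The 2.5 second clip length is a hypothetical value, and num_samples is a helper I made up for this illustration. The actual resampling in the notebook is done by torchaudio.

```python
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000


def num_samples(duration_seconds: float, sampling_rate: int) -> int:
    # A clip of a given duration holds duration * rate samples.
    return round(duration_seconds * sampling_rate)


clip_seconds = 2.5  # hypothetical clip length

print(num_samples(clip_seconds, ORIG_SAMPLING_RATE))  # 120000
print(num_samples(clip_seconds, TGT_SAMPLING_RATE))   # 40000
print(ORIG_SAMPLING_RATE // TGT_SAMPLING_RATE)        # 3
```

In other words, each resampled clip carries one third as many samples as the original recording while covering the same duration.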

Step 3.8 - CELL 8: Utility Methods to Prepare Data for Training

The eighth cell defines utility methods that process the audio and transcription data. Set the eighth cell to:


### CELL 8: Utility methods to prepare input data for training ###
def process_speech_audio(speech_array, sampling_rate):
    input_values = processor(speech_array, sampling_rate = sampling_rate).input_values
    return input_values[0]

def process_target_text(target_text):
    with processor.as_target_processor():
        encoding = processor(target_text).input_ids

    return encoding


  • The process_speech_audio method returns the input values from a supplied training sample.
  • The process_target_text method encodes each text transcription as a list of labels - i.e. a list of indices referring to characters in the vocabulary. You will see a sample encoding in Step 3.15.

Step 3.9 - CELL 9: Utility Method to Calculate Word Error Rate

The ninth cell is the final utility method cell and contains the method to calculate the Word Error Rate between a reference transcription and a predicted transcription. Set the ninth cell to:


### CELL 9: Utility method to calculate Word Error Rate ###
def compute_wer(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis = -1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens = False)

    wer = wer_metric.compute(predictions = pred_str, references = label_str)

    return {"wer": wer}
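The decoding calls in compute_wer can be a little opaque, so here is a rough pure-Python sketch of what greedy CTC decoding does as I understand it: consecutive repeated ids are collapsed, padding ids are dropped, and the word delimiter is turned back into a space. The function ctc_greedy_decode, the small id_to_char table, and the pad id of 37 are illustrative assumptions, not the library's implementation.

```python
def ctc_greedy_decode(token_ids, id_to_char, pad_id, word_delimiter="|", group_tokens=True):
    kept = []
    previous_id = None
    for token_id in token_ids:
        # With grouping on, a run of identical ids counts as a single token.
        is_repeat = group_tokens and token_id == previous_id
        previous_id = token_id
        if is_repeat or token_id == pad_id:
            continue
        kept.append(id_to_char[token_id])
    return "".join(kept).replace(word_delimiter, " ").strip()


# Hypothetical slice of a vocabulary and a made-up padding id.
id_to_char = {0: "l", 1: "a", 5: "e", 6: "n", 8: "d", 14: "o", 33: "|"}
PAD_ID = 37

# Per-frame argmax ids, as produced from the model logits.
frame_ids = [6, 6, 37, 14, 14, 33, 6, 1, 37, 8, 1, 1]
print(ctc_greedy_decode(frame_ids, id_to_char, PAD_ID))  # no nada

# Label ids are already one id per character, so grouping would
# wrongly merge the double "l" in "ella".
label_ids = [5, 0, 0, 1]
print(ctc_greedy_decode(label_ids, id_to_char, PAD_ID, group_tokens=False))  # ella
print(ctc_greedy_decode(label_ids, id_to_char, PAD_ID, group_tokens=True))   # ela
```

This is why compute_wer passes group_tokens = False when decoding the labels but not when decoding the predictions. It is also why label ids of -100 are swapped for the pad token id first, so that they are dropped during decoding.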

Step 3.10 - CELL 10: Reading Training Data

The tenth cell reads the training data index files for the recordings of male speakers and the recordings of female speakers using the read_index_file_data method defined in Step 3.6. Set the tenth cell to:


### CELL 10: Read training data ###
training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv")
training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")


  • As seen, the training data is managed in two gender-specific lists at this point. Data will be combined in Step 3.12 after truncation.

Step 3.11 - CELL 11: Truncating Training Data

The eleventh cell truncates the training data lists using the truncate_training_dataset method defined in Step 3.6. Set the eleventh cell to:


### CELL 11: Truncate training data ###
training_samples_male_cl = truncate_training_dataset(training_samples_male_cl)
training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)


  • As a reminder, the NUM_LOAD_FROM_EACH_SET constant set in Step 3.5 defines the quantity of samples to keep from each dataset. The constant is set to 1600 in this guide for a total of 3200 samples.

Step 3.12 - CELL 12: Combining Training Samples Data

The twelfth cell combines the truncated training data lists. Set the twelfth cell to:


### CELL 12: Combine training samples data ###
all_training_samples = training_samples_male_cl + training_samples_female_cl

Step 3.13 - CELL 13: Cleaning Transcription Text

The thirteenth cell iterates over each training data sample and cleans the associated transcription text using the clean_text method defined in Step 3.6. Set the thirteenth cell to:


### CELL 13: Clean transcription text ###
for index in range(len(all_training_samples)):
    all_training_samples[index][1] = clean_text(all_training_samples[index][1])

Step 3.14 - CELL 14: Creating the Vocabulary

The fourteenth cell creates a vocabulary using the cleaned transcriptions from the previous step and the create_vocab method defined in Step 3.6. Set the fourteenth cell to:


### CELL 14: Create vocabulary ###
vocab_list = create_vocab(all_training_samples)
vocab_dict = {v: i for i, v in enumerate(vocab_list)}


  • The vocabulary is stored as a dictionary with characters as keys and vocabulary indices as values.

  • You can print vocab_dict which should produce the following output:


{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11,  'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}

Step 3.15 - CELL 15: Adding Word Delimiter to the Vocabulary

The fifteenth cell adds the word delimiter character | to the vocabulary. Set the fifteenth cell to:


### CELL 15: Add word delimiter to vocabulary ###
vocab_dict["|"] = len(vocab_dict)


  • The word delimiter character is used when tokenizing text transcriptions as a list of labels. Specifically, it is used to define the end of a word and it is used when initializing the Wav2Vec2CTCTokenizer class, as will be seen in Step 3.17.

  • For example, the following list encodes no te entiendo nada using the vocabulary from Step 3.14 with the word delimiter added:


# Encoded text
[6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1]

# Vocabulary
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11,  'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}


  • A question that might naturally arise is: "Why is it necessary to define a word delimiter character?" After all, the end of a word in written English and Spanish is marked by whitespace, so it should be a simple matter to use the space character as a word delimiter. Remember that English and Spanish are just two languages among thousands, and not all written languages use a space to mark word boundaries.
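The encoding shown in this step can be reproduced with a few lines of Python. This is only a sketch of the mapping idea. encode_transcription is a hypothetical helper, and the real tokenizer additionally handles unknown characters with [UNK], whereas this version raises a KeyError.

```python
# Vocabulary from this step, including the word delimiter.
vocab_dict = {
    'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8,
    't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16,
    'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24,
    'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32,
    '|': 33,
}


def encode_transcription(text, vocab, word_delimiter="|"):
    # Replace each space with the word delimiter, then look up every character.
    return [vocab[word_delimiter if char == " " else char] for char in text]


print(encode_transcription("no te entiendo nada", vocab_dict))
# [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1]
```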

Step 3.16 - CELL 16: Exporting Vocabulary

The sixteenth cell dumps the vocabulary to a file. Set the sixteenth cell to:


### CELL 16: Export vocabulary ###
with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file:
    json.dump(vocab_dict, vocab_file)


  • The vocabulary file will be used in the next step, Step 3.17, to initialize the Wav2Vec2CTCTokenizer class.

Step 3.17 - CELL 17: Initialize the Tokenizer

The seventeenth cell initializes an instance of Wav2Vec2CTCTokenizer. Set the seventeenth cell to:


### CELL 17: Initialize tokenizer ###
tokenizer = Wav2Vec2CTCTokenizer(
    VOCAB_FILE_PATH + "vocab.json",
    unk_token = "[UNK]",
    pad_token = "[PAD]",
    word_delimiter_token = "|",
    replace_word_delimiter_char = " "
)


  • The tokenizer is used for encoding text transcriptions and decoding a list of labels back to text.

  • Note that the tokenizer is initialized with [UNK] assigned to unk_token and [PAD] assigned to pad_token, with the former used to represent unknown tokens in text transcriptions and the latter used to pad transcriptions when creating batches of transcriptions with different lengths. These two values will be added to the vocabulary by the tokenizer.

  • Initialization of the tokenizer in this step will also add two additional tokens to the vocabulary, namely <s> and </s>, which are used to demarcate the beginning and end of sentences respectively.

  • | is assigned to word_delimiter_token explicitly in this step to reflect that the pipe symbol will be used to demarcate the end of words in accordance with our addition of the character to the vocabulary in Step 3.15. The | symbol is the default value for word_delimiter_token. So, it did not need to be explicitly set but was done so for the sake of clarity.

  • As with word_delimiter_token, a single space is explicitly assigned to replace_word_delimiter_char, reflecting that the pipe symbol | will be used to replace blank space characters in text transcriptions. Blank space is the default value for replace_word_delimiter_char. So, it also did not need to be explicitly set but was done so for the sake of clarity.

  • You can print the full tokenizer vocabulary by calling the get_vocab() method on tokenizer.


vocab = tokenizer.get_vocab()
print(vocab)

# Output:
{'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}

Step 3.18 - CELL 18: Initializing the Feature Extractor

The eighteenth cell initializes an instance of Wav2Vec2FeatureExtractor. Set the eighteenth cell to:


### CELL 18: Initialize feature extractor ###
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size = 1,
    sampling_rate = 16000,
    padding_value = 0.0,
    do_normalize = True,
    return_attention_mask = True
)


  • The feature extractor is used to extract features from input data which is, of course, audio data in this use case. You will load the audio data for each training data sample in Step 3.20.
  • The parameter values passed to the Wav2Vec2FeatureExtractor initializer are all default values, with the exception of return_attention_mask which defaults to False. The default values are shown/passed for the sake of clarity.
  • The feature_size parameter specifies the dimension size of input features (i.e. audio data features). The default value of this parameter is 1.
  • sampling_rate tells the feature extractor the sampling rate at which the audio data should be digitized. As discussed in Step 3.7, wav2vec2 is pretrained on audio sampled at 16000 Hz and hence 16000 is the default value for this parameter.
  • The padding_value parameter specifies the value that is used when padding audio data, as required when batching audio samples of different lengths. The default value is 0.0.
  • do_normalize is used to specify if input data should be normalized to zero mean and unit variance. The default value is True. Wav2Vec2FeatureExtractor class documentation notes that "[normalizing] can help to significantly improve the performance for some models."
  • The return_attention_mask parameter specifies if the attention mask should be returned or not. The value is set to True for this use case.
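For readers who want to see what normalizing a waveform looks like, here is a small sketch that rescales a list of samples to zero mean and unit variance, which is my understanding of what do_normalize does. The function name, the eps value, and the sample values are assumptions made for illustration. This is not the library's code.

```python
import math
import statistics


def zero_mean_unit_variance(samples, eps=1e-7):
    # Subtract the mean and divide by the standard deviation.
    # eps (an assumed small constant) guards against division by zero.
    mean = statistics.fmean(samples)
    variance = statistics.pvariance(samples, mu=mean)
    scale = math.sqrt(variance + eps)
    return [(x - mean) / scale for x in samples]


waveform = [0.1, 0.4, -0.2, 0.3]  # made-up audio samples
normalized = zero_mean_unit_variance(waveform)

print(abs(statistics.fmean(normalized)) < 1e-9)            # True
print(abs(statistics.pvariance(normalized) - 1.0) < 1e-4)  # True
```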

Step 3.19 - CELL 19: Initializing the Processor

The nineteenth cell initializes an instance of Wav2Vec2Processor. Set the nineteenth cell to:


### CELL 19: Initialize processor ###
processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)


  • The Wav2Vec2Processor class combines tokenizer and feature_extractor from Step 3.17 and Step 3.18 respectively into a single processor.

  • Note that the processor configuration can be saved by calling the save_pretrained method on the Wav2Vec2Processor class instance.


processor.save_pretrained(OUTPUT_DIR_PATH)

Step 3.20 - CELL 20: Loading Audio Data

The twentieth cell loads each audio file specified in the all_training_samples list. Set the twentieth cell to:


### CELL 20: Load audio data ###
all_input_data = []

for index in range(len(all_training_samples)):
    speech_array, sampling_rate = read_audio_data(all_training_samples[index][0])
    all_input_data.append({
        "input_values": speech_array,
        "labels": all_training_samples[index][1]
    })


  • Audio data is returned as a torch.Tensor and stored in all_input_data as a list of dictionaries. Each dictionary contains the audio data for a particular sample, along with the text transcription of the audio.
  • Note that the read_audio_data method returns the sampling rate of the audio data as well. Since we know that the sampling rate is 48000 Hz for all audio files in this use case, the sampling rate is ignored in this step.

Step 3.21 - CELL 21: Converting all_input_data to a Pandas DataFrame

The twenty-first cell converts the all_input_data list to a Pandas DataFrame to make it easier to manipulate the data. Set the twenty-first cell to:


### CELL 21: Convert audio training data list to Pandas DataFrame ###
all_input_data_df = pd.DataFrame(data = all_input_data)

Step 3.22 - CELL 22: Processing Audio Data and Text Transcriptions

The twenty-second cell uses the processor initialized in Step 3.19 to extract features from each audio data sample and to encode each text transcription as a list of labels. Set the twenty-second cell to:


### CELL 22: Process audio data and text transcriptions ###
all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000))
all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))

Step 3.23 - CELL 23: Splitting Input Data into Training and Validation Datasets

The twenty-third cell splits the all_input_data_df DataFrame into training and evaluation (validation) datasets using the SPLIT_PCT constant from Step 3.5. Set the twenty-third cell to:


### CELL 23: Split input data into training and validation datasets ###
# Use the actual row count so the split also works when NUM_LOAD_FROM_EACH_SET is "all"
split = math.floor(len(all_input_data_df) * SPLIT_PCT)
valid_data_df = all_input_data_df.iloc[-split:]
train_data_df = all_input_data_df.iloc[:-split]


  • The SPLIT_PCT value is 0.10 in this guide meaning 10% of all input data will be held out for evaluation and 90% of the data will be used for training/finetuning.
  • Since there are a total of 3,200 training samples, 320 samples will be used for evaluation with the remaining 2,880 samples used to finetune the model.
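The split arithmetic can be checked with a plain list standing in for the DataFrame rows. This is only a sketch, and samples is a placeholder for the real data.

```python
import math

NUM_LOAD_FROM_EACH_SET = 1600
SPLIT_PCT = 0.10

# Stand-in for the 3,200 DataFrame rows (1,600 from each dataset).
samples = list(range(NUM_LOAD_FROM_EACH_SET * 2))

split = math.floor(len(samples) * SPLIT_PCT)
valid = samples[-split:]  # last 10% of rows
train = samples[:-split]  # everything before that

print(split, len(train), len(valid))  # 320 2880 320
```

One thing worth keeping in mind is that the validation rows are taken from the tail of the data. Since Step 3.12 appends the female recordings after the male recordings, the held-out rows in this guide all come from the female set.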

Step 3.24 - CELL 24: Converting Training and Validation Datasets to Dataset Objects

The twenty-fourth cell converts the train_data_df and valid_data_df DataFrames to Dataset objects. Set the twenty-fourth cell to:


### CELL 24: Convert training and validation datasets to Dataset objects ###
train_data = Dataset.from_pandas(train_data_df)
valid_data = Dataset.from_pandas(valid_data_df)


  • Dataset objects are consumed by HuggingFace Trainer class instances, as you will see in Step 3.30.

  • These objects contain metadata about the dataset as well as the dataset itself.

  • You can print train_data and valid_data to view the metadata for both Dataset objects.


print(train_data)
print(valid_data)

# Output:
Dataset({
    features: ['input_values', 'labels'],
    num_rows: 2880
})
Dataset({
    features: ['input_values', 'labels'],
    num_rows: 320
})

Step 3.25 - CELL 25: Initializing the Pretrained Model

The twenty-fifth cell initializes the pretrained XLS-R (0.3) model. Set the twenty-fifth cell to:


### CELL 25: Initialize pretrained model ###
model = Wav2Vec2ForCTC.from_pretrained(
    MODEL,
    ctc_loss_reduction = "mean",
    pad_token_id = processor.tokenizer.pad_token_id,
    vocab_size = len(processor.tokenizer)
)


  • The from_pretrained method called on Wav2Vec2ForCTC specifies that we want to load the pretrained weights for the specified model.
  • The MODEL constant was specified in Step 3.5 and was set to facebook/wav2vec2-xls-r-300m reflecting the XLS-R (0.3) model.
  • The ctc_loss_reduction parameter specifies the type of reduction to apply to the output of the Connectionist Temporal Classification ("CTC") loss function. CTC loss is used to calculate the loss between a continuous input, in this case audio data, and a target sequence, in this case text transcriptions. By setting the value to mean, the loss for each input in a batch is divided by the length of its target sequence, and the mean of those values over the batch is then taken as the batch loss.
  • pad_token_id specifies the token to be used for padding when batching. It is set to the [PAD] id set when initializing the tokenizer in Step 3.17.
  • The vocab_size parameter defines the vocabulary size of the model. It is the vocabulary size after initialization of the tokenizer in Step 3.17 and sets the number of nodes in the model's final output layer, one per vocabulary token.
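To make the mean reduction concrete, here is a small sketch in plain Python. It does not compute CTC loss itself. It only shows how already-computed per-sample losses would be combined under the "mean" and "sum" reductions as described above. The function name reduce_ctc_losses and the numbers are hypothetical and are not part of the notebook.

```python
def reduce_ctc_losses(per_sample_losses, target_lengths, reduction="mean"):
    """Combine per-sample CTC losses into a single batch loss (illustrative only)."""
    if reduction == "sum":
        # "sum": add up the raw per-sample losses.
        return sum(per_sample_losses)
    if reduction == "mean":
        # "mean": divide each loss by its target length, then average over the batch.
        normalized = [
            loss / length
            for loss, length in zip(per_sample_losses, target_lengths)
        ]
        return sum(normalized) / len(normalized)
    raise ValueError(f"Unsupported reduction: {reduction}")


# Hypothetical batch of two samples with targets of 4 and 10 labels.
losses = [12.0, 30.0]
lengths = [4, 10]

print(reduce_ctc_losses(losses, lengths, reduction="mean"))
print(reduce_ctc_losses(losses, lengths, reduction="sum"))
```

In this example both samples contribute 3.0 after normalization, so the longer transcription does not dominate the batch loss the way it would under "sum".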

Step 3.26 - CELL 26: Freezing Feature Extractor Weights

The twenty-sixth cell freezes the pretrained weights of the feature extractor. Set the twenty-sixth cell to:


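### CELL 26: Freeze feature extractor ###
# Note: recent transformers releases deprecate this method in favor of the
# equivalent model.freeze_feature_encoder().
model.freeze_feature_extractor()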

Step 3.27 - CELL 27: Setting Training Arguments

The twenty-seventh cell initializes the training arguments that will be passed to a Trainer instance. Set the twenty-seventh cell to:


### CELL 27: Set training arguments ###
training_args = TrainingArguments(
    output_dir = OUTPUT_DIR_PATH,
    save_safetensors = False,
    group_by_length = True,
    per_device_train_batch_size = TRAIN_BATCH_SIZE,
    per_device_eval_batch_size = EVAL_BATCH_SIZE,
    num_train_epochs = TRAIN_EPOCHS,
    gradient_checkpointing = True,
    evaluation_strategy = "steps",
    save_strategy = "steps",
    logging_strategy = "steps",
    eval_steps = EVAL_STEPS,
    save_steps = SAVE_STEPS,
    logging_steps = LOGGING_STEPS,
    learning_rate = LEARNING_RATE,
    warmup_steps = WARMUP_STEPS
)


  • The TrainingArguments class accepts more than 100 parameters.
  • The save_safetensors parameter when False specifies that the finetuned model should be saved to a pickle file instead of using the safetensors format.
  • The group_by_length parameter when True indicates that samples of approximately the same length should be grouped together. This minimizes padding and improves training efficiency.
  • per_device_train_batch_size sets the number of samples per training mini-batch. This parameter is set to 18 via the TRAIN_BATCH_SIZE constant assigned in Step 3.5. With 2,880 training samples, this implies 160 steps per epoch (2,880 / 18).
  • per_device_eval_batch_size sets the number of samples per evaluation (holdout) mini-batch. This parameter is set to 10 via the EVAL_BATCH_SIZE constant assigned in Step 3.5.
  • num_train_epochs sets the number of training epochs. This parameter is set to 30 via the TRAIN_EPOCHS constant assigned in Step 3.5. This implies 4,800 total steps during training (160 steps x 30 epochs).
  • The gradient_checkpointing parameter when True helps to save memory by storing only a subset of intermediate activations and recomputing the rest during the backward pass, which results in slower backward passes.
  • The evaluation_strategy parameter when set to steps means that evaluation will be performed and logged during training at an interval specified by the parameter eval_steps.
  • The logging_strategy parameter when set to steps means that training run statistics will be logged at an interval specified by the parameter logging_steps.
  • The save_strategy parameter when set to steps means that a checkpoint of the finetuned model will be saved at an interval specified by the parameter save_steps.
  • eval_steps sets the number of steps between evaluations of holdout data. This parameter is set to 100 via the EVAL_STEPS constant assigned in Step 3.5.
  • save_steps sets the number of steps after which a checkpoint of the finetuned model is saved. This parameter is set to 3200 via the SAVE_STEPS constant assigned in Step 3.5.
  • logging_steps sets the number of steps between logs of training run statistics. This parameter is set to 100 via the LOGGING_STEPS constant assigned in Step 3.5.
  • The learning_rate parameter sets the initial learning rate. This parameter is set to 1e-4 via the LEARNING_RATE constant assigned in Step 3.5.
  • The warmup_steps parameter sets the number of steps to linearly warmup the learning rate from 0 to the value set by learning_rate. This parameter is set to 800 via the WARMUP_STEPS constant assigned in Step 3.5.
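The step counts and warmup behavior described above can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: it assumes a single GPU with no gradient accumulation, and it assumes the learning rate decays linearly to 0 after warmup, which is the Trainer's default schedule when no scheduler type is set. The guide itself does not state the schedule, and the helper learning_rate_at is a hypothetical name.

```python
import math

# Values taken from the bullets above.
TRAIN_SAMPLES = 2880
TRAIN_BATCH_SIZE = 18
TRAIN_EPOCHS = 30
EVAL_STEPS = 100
SAVE_STEPS = 3200
LEARNING_RATE = 1e-4
WARMUP_STEPS = 800

steps_per_epoch = math.ceil(TRAIN_SAMPLES / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * TRAIN_EPOCHS
num_evaluations = total_steps // EVAL_STEPS   # evaluations run during training
num_checkpoints = total_steps // SAVE_STEPS   # intermediate checkpoints saved


def learning_rate_at(step):
    """Learning rate at a given step: linear warmup, then (assumed) linear decay to 0."""
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / WARMUP_STEPS
    remaining = max(0, total_steps - step)
    return LEARNING_RATE * remaining / (total_steps - WARMUP_STEPS)


print(steps_per_epoch, total_steps, num_evaluations, num_checkpoints)
for step in (0, 400, 800, 2800, 4800):
    print(step, learning_rate_at(step))
```

Under these assumptions the run performs 48 evaluations and writes one intermediate checkpoint at step 3,200. The learning rate reaches its peak of 1e-4 at step 800, which is the end of the fifth epoch.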

Step 3.28 - CELL 28: Defining Data Collator Logic

The twenty-eighth cell defines the logic for dynamically padding input and target sequences. Set the twenty-eighth cell to:


### CELL 28: Define data collator logic ###
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding = self.padding,
            max_length = self.max_length,
            pad_to_multiple_of = self.pad_to_multiple_of,
            return_tensors = "pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding = self.padding,
                max_length = self.max_length_labels,
                pad_to_multiple_of = self.pad_to_multiple_of_labels,
                return_tensors = "pt",
            )

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch


  • Training and evaluation input-label pairs are passed in mini-batches to the Trainer instance that will be initialized momentarily in Step 3.30. Since the input sequences and label sequences vary in length in each mini-batch, some sequences must be padded so that they are all of the same length.
  • The DataCollatorCTCWithPadding class dynamically pads mini-batch data. The padding parameter when set to True specifies that shorter audio input feature sequences and label sequences should be padded to the length of the longest sequence in a mini-batch.
  • Audio input features are padded with the value 0.0 set when initializing the feature extractor in Step 3.18.
  • Label inputs are first padded with the padding value set when initializing the tokenizer in Step 3.17. These values are then replaced by -100 so that the padded label positions are ignored when calculating the CTC loss during training.
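The padding behavior can be illustrated without the processor. The sketch below mimics what the collator does to one mini-batch using plain Python lists. It is not the collator itself: pad_batch is a hypothetical helper, and PAD_TOKEN_ID = 0 is an assumed value standing in for the [PAD] id produced by the tokenizer in Step 3.17.

```python
PAD_TOKEN_ID = 0     # assumed stand-in for the tokenizer's [PAD] id
IGNORE_INDEX = -100  # value that replaces padded label positions


def pad_batch(features, input_pad_value=0.0):
    """Pad inputs and labels to the longest sequence in the mini-batch."""
    max_input_len = max(len(f["input_values"]) for f in features)
    max_label_len = max(len(f["labels"]) for f in features)

    input_values, labels = [], []
    for f in features:
        # Audio features are padded with 0.0 up to the longest input.
        n_input_pad = max_input_len - len(f["input_values"])
        input_values.append(list(f["input_values"]) + [input_pad_value] * n_input_pad)

        # Labels are padded with the pad id, then masked via the attention mask,
        # mirroring masked_fill(attention_mask.ne(1), -100) in CELL 28.
        n_label_pad = max_label_len - len(f["labels"])
        padded = list(f["labels"]) + [PAD_TOKEN_ID] * n_label_pad
        attention_mask = [1] * len(f["labels"]) + [0] * n_label_pad
        labels.append(
            [tok if keep == 1 else IGNORE_INDEX for tok, keep in zip(padded, attention_mask)]
        )

    return {"input_values": input_values, "labels": labels}


# Hypothetical mini-batch of two samples with different lengths.
batch = pad_batch([
    {"input_values": [0.1, -0.2, 0.3], "labels": [5, 6]},
    {"input_values": [0.4], "labels": [7, 8, 9]},
])
print(batch["input_values"])
print(batch["labels"])
```

Masking by attention mask rather than by value means that only positions added by padding are replaced, so a genuine label would be preserved even if it happened to share the pad id.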

Step 3.29 - CELL 29: Initializing Instance of Data Collator

The twenty-ninth cell initializes an instance of the data collator defined in the previous step. Set the twenty-ninth cell to:


### CELL 29: Initialize instance of data collator ###
data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)

Step 3.30 - CELL 30: Initializing the Trainer

The thirtieth cell initializes an instance of the Trainer class. Set the thirtieth cell to:


### CELL 30: Initialize trainer ###
trainer = Trainer(
    model = model,
    data_collator = data_collator,
    args = training_args,
    compute_metrics = compute_wer,
    train_dataset = train_data,
    eval_dataset = valid_data,
    tokenizer = processor.feature_extractor
)


  • As seen, the Trainer class is initialized with:
    • The pretrained model initialized in Step 3.25.
    • The data collator initialized in Step 3.29.
    • The training arguments initialized in Step 3.27.
    • The WER evaluation method defined in Step 3.9.
    • The train_data Dataset object from Step 3.24.
    • The valid_data Dataset object from Step 3.24.
  • The tokenizer parameter is assigned to processor.feature_extractor and works with data_collator to automatically pad the inputs to the maximum-length input of each mini-batch.

Step 3.31 - CELL 31: Finetuning the Model

The thirty-first cell calls the train method on the Trainer class instance to finetune the model. Set the thirty-first cell to:


### CELL 31: Finetune the model ###
trainer.train()

Step 3.32 - CELL 32: Saving the Finetuned Model

The thirty-second cell is the last notebook cell. It saves the finetuned model by calling the save_model method on the Trainer instance. Set the thirty-second cell to:


### CELL 32: Save the finetuned model ###
trainer.save_model(OUTPUT_DIR_PATH)

Step 4 - Training and Saving the Model

Step 4.1 - Training the Model

Now that all the cells of the notebook have been built, it’s time to start finetuning.


  1. Set the Kaggle Notebook to run with the NVIDIA GPU P100 accelerator.

  2. Commit the notebook on Kaggle.

  3. Monitor training run data by logging into your WandB account and locating the associated run.


Training over 30 epochs should take ~5 hours using the NVIDIA GPU P100 accelerator. The WER on holdout data should drop to ~0.15 at the end of training. It’s not quite a state-of-the-art result, but the finetuned model is still sufficiently useful for many applications.
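For readers who want a feel for what a WER of ~0.15 means, the following is a minimal sketch of the metric for a single sentence pair: the word-level edit distance (substitutions, deletions, and insertions) divided by the number of reference words. This is not the compute_wer method from Step 3.9. The function word_error_rate and the example sentences are hypothetical, and evaluation libraries typically aggregate errors over the whole dataset rather than averaging per sentence.

```python
def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words."""
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between the reference words processed
    # so far and the first j hypothesis words.
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion (reference word missing)
                curr[j - 1] + 1,     # insertion (extra hypothesis word)
                prev[j - 1] + cost,  # match or substitution
            ))
        prev = curr
    return prev[-1] / len(ref_words)


# Hypothetical example: the model drops one word out of six.
print(word_error_rate("el gato duerme en la casa", "el gato duerme en casa"))
```

In the example one of six reference words is missing, giving a WER of 1/6, or roughly 0.17. A WER of ~0.15 therefore corresponds to about one word error in every six or seven reference words.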

Step 4.2 - Saving the Model

The finetuned model will be output to the Kaggle directory set by the OUTPUT_DIR_PATH constant in Step 3.5. The model output should include the following files:


pytorch_model.bin
config.json
preprocessor_config.json
vocab.json
training_args.bin
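
Before downloading or uploading anything, it can be handy to confirm that all five files are present. The following is a small sketch using only the Python standard library. The helper missing_model_files is a hypothetical name and not part of the notebook.

```python
from pathlib import Path

# File names listed above.
EXPECTED_FILES = [
    "pytorch_model.bin",
    "config.json",
    "preprocessor_config.json",
    "vocab.json",
    "training_args.bin",
]


def missing_model_files(output_dir):
    """Return the expected model files that are not present in output_dir."""
    directory = Path(output_dir)
    return [name for name in EXPECTED_FILES if not (directory / name).is_file()]
```

Calling missing_model_files(OUTPUT_DIR_PATH) in a notebook cell after training should return an empty list if the save completed as described.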


These files can be downloaded locally. Additionally, you can create a new Kaggle Model using the model files. The Kaggle Model will be used with the companion inference guide to run inference on the finetuned model.


  1. Log in to your Kaggle account. Click on Models > New Model.
  2. Add a title for your finetuned model in the Model Title field.
  3. Click on Create Model.
  4. Click on Go to model detail page.
  5. Click on Add new variation under Model Variations.
  6. Select Transformers from the Framework select menu.
  7. Click on Add new variation.
  8. Drag and drop your finetuned model files into the Upload Data window. Alternatively, click on the Browse Files button to open a file explorer window and select your finetuned model files.
  9. Once the files have uploaded to Kaggle, click on Create to create the Kaggle Model.

Conclusion

Congratulations on finetuning wav2vec2 XLS-R! Remember that you can use these general steps to finetune the model on other languages. Running inference on the finetuned model generated in this guide is fairly straightforward, and the inference steps are outlined in a separate companion guide. Please search for my HackerNoon username to find the companion guide.