paint-brush
使用 wav2vec2 第 2 部分 - 在经过微调的 ASR 模型上运行推理经过@pictureinthenoise
436 讀數
436 讀數

使用 wav2vec2 第 2 部分 - 在经过微调的 ASR 模型上运行推理

经过 Picture in the Noise11m2024/05/07
Read on Terminal Reader

太長; 讀書

本配套指南介绍了在经过微调的 wav2vec2 XLS-R 模型上运行推理的步骤。它是指南“使用 wav2vec2 第 1 部分 - 微调 XLS-R 进行自动语音识别”的补充。该指南提供了有关创建可用于运行推理的 Kaggle Notebook 的分步说明。
featured image - 使用 wav2vec2 第 2 部分 - 在经过微调的 ASR 模型上运行推理
Picture in the Noise HackerNoon profile picture
0-item
1-item

介绍

这是使用 wav2vec2 第 1 部分 - 微调 XLS-R 以进行自动语音识别(“第 1 部分指南”)的配套指南。我编写了第 1 部分指南,介绍如何针对智利西班牙语微调 Meta AI 的wav2vec2 XLS-R(“XLS-R”)模型。假设您已完成该指南并生成了您自己的微调 XLS-R 模型。本指南将解释通过Kaggle Notebook对微调后的 XLS-R 模型进行推理的步骤。

先决条件和开始之前

要完成本指南,您需要具备:


  • 针对西班牙语进行微调的 XLS-R 模型。
  • 现有的Kaggle 帐户
  • 具有中级 Python 知识。
  • 具有使用 Kaggle Notebooks 的中级知识。
  • 对 ML 概念有中级了解。
  • ASR 概念的基本知识。

构建推理笔记本

步骤 1 - 设置你的 Kaggle 环境

步骤 1.1 - 创建新的 Kaggle Notebook

  1. 登录 Kaggle。
  2. 创建一个新的 Kaggle Notebook。
  3. 可以根据需要更改笔记本的名称。本指南使用笔记本名称spanish-asr-inference

步骤 1.2 - 添加测试数据集

本指南使用秘鲁西班牙语语音数据集作为测试数据来源。与智利西班牙语语音数据集一样,秘鲁语者数据集也由两个子数据集组成:2,918 条秘鲁男性说话者的录音和 2,529 条秘鲁女性说话者的录音。


该数据集已作为 2 个不同的数据集上传至 Kaggle:


单击添加输入,将这两个数据集都添加到您的 Kaggle Notebook 中。

步骤 1.3 - 添加微调模型

您应该在使用 wav2vec2 第 1 部分 - 对 XLS-R 进行自动语音识别微调指南的第 4 步中将微调后的模型保存为Kaggle 模型


单击“添加输入” ,将微调后的模型添加到你的 Kaggle Notebook 中。

第 2 步 - 构建推理笔记本

以下 16 个子步骤按顺序构建推理笔记本的 16 个单元。您会注意到,这里使用了许多与第 1 部分指南相同的实用方法。

步骤 2.1 - CELL 1:安装包

推理笔记本的第一个单元格安装依赖项。将第一个单元格设置为:


 ### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer

步骤 2.2 - CELL 2:导入 Python 包

第二个单元格导入所需的 Python 包。将第二个单元格设置为:


 ### CELL 2: Import Python packages ### import re import math import random import pandas as pd import torchaudio from datasets import load_metric from transformers import pipeline

步骤 2.3 - 单元格 3:加载 WER 指标

第三个单元格导入 HuggingFace WER 评估指标。将第三个单元格设置为:


 ### CELL 3: Load WER metric ### wer_metric = load_metric("wer")


  • WER 将用于衡量微调模型在测试数据上的性能。

步骤 2.4 - 单元格 4:设置常数

第四个单元格设置将在整个笔记本中使用的常量。将第四个单元格设置为:


 ### CELL 4: Constants ### # Testing data TEST_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-peru-male/" TEST_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-peru-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 3 # Special characters SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000

步骤 2.5 - CELL 5:读取索引文件、清理文本和创建词汇表的实用方法

第五个单元格定义了用于读取数据集索引文件以及清理转录文本和从测试数据中生成一组随机样本的实用方法。将第五个单元格设置为:


 ### CELL 5: Utility methods for reading index files, cleaning text, random indices generator ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def get_random_samples(dataset: list, num: int) -> list: used = [] samples = [] for i in range(num): a = -1 while a == -1 or a in used: a = math.floor(len(dataset) * random.random()) samples.append(dataset[a]) used.append(a) return samples


  • read_index_file_data方法读取line_index.tsv数据集索引文件并生成包含音频文件名和转录数据的列表列表,例如:


 [ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]


  • clean_text方法用于删除每个文本转录中步骤 2.4中分配给SPECIAL_CHARS的正则表达式所指定的字符。这些字符(包括标点符号)可以被删除,因为它们在训练模型以学习音频特征和文本转录之间的映射时不提供任何语义价值。
  • get_random_samples方法返回一组随机测试样本,其数量由步骤 2.4中的常量NUM_LOAD_FROM_EACH_SET设置。

步骤 2.6 - 单元 6:用于加载和重采样音频数据的实用方法

第六个单元格定义使用torchaudio加载和重新采样音频数据的实用方法。将第六个单元格设置为:


 ### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]


  • read_audio_data方法加载指定的音频文件并返回音频数据的torch.Tensor多维矩阵以及音频的采样率。训练数据中的所有音频文件的采样率均为48000 Hz。此“原始”采样率由步骤 2.4中的常量ORIG_SAMPLING_RATE捕获。
  • resample方法用于将音频数据从采样率48000下采样到目标采样率16000

步骤 2.7 - 单元 7:读取测试数据

第七个单元格使用步骤 2.5中定义的read_index_file_data方法读取男性说话者录音和女性说话者录音的测试数据索引文件。将第七个单元格设置为:


 ### CELL 7: Read test data ### test_data_male = read_index_file_data(TEST_DATA_PATH_MALE, "line_index.tsv") test_data_female = read_index_file_data(TEST_DATA_PATH_FEMALE, "line_index.tsv")

步骤 2.8 - 单元格 8:生成随机测试样本列表

第八个单元格使用步骤 2.5中定义的get_random_samples方法生成随机测试样本集。将第八个单元格设置为:


 ### CELL 8: Generate lists of random test samples ### random_test_samples_male = get_random_samples(test_data_male, NUM_LOAD_FROM_EACH_SET) random_test_samples_female = get_random_samples(test_data_female, NUM_LOAD_FROM_EACH_SET)

步骤 2.9 - 单元格 9:合并测试数据

第九个单元格将男性测试样本和女性测试样本合并为一个列表。将第九个单元格设置为:


 ### CELL 9: Combine test data ### all_test_samples = random_test_samples_male + random_test_samples_female

步骤 2.10 - 单元格 10:清理转录测试

第十个单元格对每个测试数据样本进行迭代,并使用步骤 2.5中定义的clean_text方法清理相关的转录文本。将第十个单元格设置为:


 ### CELL 10: Clean text transcriptions ### for index in range(len(all_test_samples)): all_test_samples[index][1] = clean_text(all_test_samples[index][1])

步骤 2.11 - 单元格 11:加载音频数据

第 11 个单元格加载all_test_samples列表中指定的每个音频文件。将第 11 个单元格设置为:


 ### CELL 11: Load audio data ### all_test_data = [] for index in range(len(all_test_samples)): speech_array, sampling_rate = read_audio_data(all_test_samples[index][0]) all_test_data.append({ "raw": speech_array, "sampling_rate": sampling_rate, "target_text": all_test_samples[index][1] })


  • 音频数据以torch.Tensor的形式返回,并以字典列表的形式存储在all_test_data中。每个字典包含特定样本的音频数据、采样率和音频的文本转录。

步骤 2.12 - 单元格 12:重新采样音频数据

第十二个单元格将音频数据重新采样为目标采样率16000 。将第十二个单元格设置为:


 ### CELL 12: Resample audio data and cast to NumPy arrays ### all_test_data = [{"raw": resample(sample["raw"]).numpy(), "sampling_rate": TGT_SAMPLING_RATE, "target_text": sample["target_text"]} for sample in all_test_data]

步骤 2.13 - 单元 13:初始化自动语音识别管道实例

第十三个单元格初始化 HuggingFace transformerpipeline类的实例。将第十三个单元格设置为:


 ### CELL 13: Initialize instance of Automatic Speech Recognition Pipeline ### transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")


  • model参数必须设置为步骤 1.3中添加到 Kaggle Notebook 的微调模型的路径,例如:


 transcriber = pipeline("automatic-speech-recognition", model = "/kaggle/input/xls-r-300m-chilean-spanish/transformers/hardy-pine/1")

步骤 2.14 - 单元格 14:生成预测

第 14 个单元格在测试数据上调用上一步初始化的transcriber来生成文本预测。将第 14 个单元格设置为:


 ### CELL 14: Generate transcriptions ### transcriptions = transcriber(all_test_data)

步骤 2.15 - 单元格 15:计算 WER 指标

第十五个单元格计算每个预测的 WER 分数以及所有预测的总体 WER 分数。将第十五个单元格设置为:


 ### CELL 15: Calculate WER metrics ### predictions = [transcription["text"] for transcription in transcriptions] references = [transcription["target_text"][0] for transcription in transcriptions] wers = [] for p in range(len(predictions)): wer = wer_metric.compute(predictions = [predictions[p]], references = [references[p]]) wers.append(wer) zipped = list(zip(predictions, references, wers)) df = pd.DataFrame(zipped, columns=["Prediction", "Reference", "WER"]) wer = wer_metric.compute(predictions = predictions, references = references)

步骤 2.16 - 单元格 16:打印 WER 指标

第十六个单元格(也是最后一个单元格)仅打印上一步中的 WER 计算结果。将第十六个单元格设置为:


 ### CELL 16: Output WER metrics ### pd.set_option("display.max_colwidth", None) print(f"Overall WER: {wer}") print(df)

WER 分析

由于笔记本会根据测试数据的随机样本生成预测,因此每次运行笔记本时输出都会有所不同。在运行笔记本时, NUM_LOAD_FROM_EACH_SET设置为3 ,总共 6 个测试样本,生成了以下输出:


 Overall WER: 0.013888888888888888 Prediction \ 0 quiero que me reserves el mejor asiento del teatro 1 el llano en llamas es un clásico de juan rulfo 2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera 3 hay tres cafés que están abiertos hasta las once de la noche 4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras 5 cuántos albergues se abrieron después del terremoto del diecinueve de setiembre Reference \ 0 quiero que me reserves el mejor asiento del teatro 1 el llano en llamas es un clásico de juan rulfo 2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera 3 hay tres cafés que están abiertos hasta las once de la noche 4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras 5 cuántos albergues se abrieron después del terremoto del diecinueve de septiembre WER 0 0.000000 1 0.000000 2 0.000000 3 0.000000 4 0.000000 5 0.090909


可以看出,该模型表现非常出色!它只在第六个样本(索引5 )中犯了一个错误,将单词septiembre拼写为setiembre 。当然,使用不同的测试样本(更重要的是,使用更多的测试样本)再次运行该笔记本,将产生不同且更具信息量的结果。尽管如此,这些有限的数据表明该模型可以在不同的西班牙语方言上表现良好 - 例如,它是用智利西班牙语训练的,但似乎在秘鲁西班牙语上表现良好。

结论

如果您刚刚开始学习如何使用 wav2vec2 模型,我希望《使用 wav2vec2 第 1 部分 - 微调 XLS-R 进行自动语音识别》指南和本指南对您有所帮助。如前所述,第 1 部分指南生成的微调模型并不是最先进的,但仍可用于许多应用。祝您构建愉快!