运用 Transformers 为多语种语音辨认使命微调 Whisper 模型
本文供给了一个运用 Hugging Face Transformers 在恣意多语种语音辨认 (ASR) 数据集上微调 Whisper 的分步指南。同时,咱们还深化解释了 Whisper 模型、Common Voice 数据集以及微调等理论常识,并供给了数据预备和微调的相关代码。假如你想要一个全部是代码,仅有少数解释的 Notebook,能够参阅这个 Google Colab。
目录
- 简介
-
在 Google Colab 中微调 Whisper
- 预备环境
- 加载数据集
- 预备特征提取器、分词器和数据
- 练习与评价
- 构建演示运用
- 结束语
简介
Whisper 是一系列用于自动语音辨认 (automatic speech recognition,ASR) 的预练习模型,它由来自于 OpenAI 的 Alec Radford 等人于 2022 年 9 月 发布。与 Wav2Vec 2.0 等前作不同,以往的模型都是在未标示的音频数据上预练习的,而 Whisper 是在很多的 已标示 音频转录数据上预练习的。其用于练习的标示音频时长高达 68 万小时,比 Wav2Vec 2.0 运用的未标示练习数据 (6 万小时) 还多一个数量级。更妙的是,该预练习数据中还含有 11.7 万小时的多语种数据。因而,Whisper 训得的 checkpoint 可运用于超越 96 种言语,这其间包含不少 数据匮乏 的小语种。
这么多的标示数据使得咱们能够直接在 有监督 语音辨认使命上预练习 Whisper,从标示音频转录数据 1{}^1 中直接习得语音到文本的映射。因而,Whisper 简直不需求额定的微调就现已是高性能的 ASR 模型了。这让 Wav2Vec 2.0 相形见绌,由于 Wav2Vec 2.0 是在 无监督 掩码猜测使命上预练习的,所以其训得的模型仅从未标示的纯音频数据中习得了从语音到隐含状况的中心映射。虽然无监督预练习能产生高质量的语音表征,但它 学不到语音到文本的映射,要学到语音到文本的映射只能靠微调。因而,Wav2Vec 2.0 需求更多的微调才能取得较有竞争力的性能。
在 68 万小时标示数据的加持下,预练习 Whisper 模型体现出了强壮的泛化到多种数据集和范畴的才能。其预练习 checkpoint 体现出了与最先进的 ASR 体系旗鼓相当的性能: 在 LibriSpeech ASR 的无噪测验子集上的单词错误率 (word error rate,WER) 仅为约 3%,别的它还在 TED-LIUM 上创下了新的记载 – 4.7% 的 WER ( 详见 Whisper 论文 的表 8)。Whisper 在预练习期间取得的广泛的多语种 ASR 常识对一些数据匮乏的小语种特别有用。稍稍微调一下,预练习 checkpoint 就能够进一步适配特定的数据集和语种,然后进一步改进在这些语种上的辨认作用。
Whisper 是一个依据 transformer 的编码器 – 解码器模型 (也称为 序列到序列 模型),它将音频的频谱图特征 序列 映射到文本的词 序列。首要,经过特征提取器将原始音频输入变换为对数梅尔声谱图 (log-Mel spectrogram)。然后,transformer 编码器对声谱图进行编码,生成一系列编码器隐含状况。终究,解码器依据先前输出的词以及编码器隐含状况,自回归地猜测下一个输出词。图 1 是 Whisper 模型的示意图。
在序列到序列模型中,编码器担任从语音中提取出重要特征,将输入转化为一组隐含状况表征。解码器扮演言语模型的角色,处理隐含状况表征并生成对应的文本。咱们把在模型架构 内部 集成言语模型的做法称为 深度交融。与之相对的是 浅交融,此时,言语模型在 外部与编码器组合,如 CTC + nn-gram ( 详见 Internal Language Model Estimation 一文)。经过深度交融,能够用同一份练习数据和损失函数对整个体系进行端到端练习,然后取得更大的灵活性和更优越的性能 ( 详见 ESB Benchmark)。
Whisper 运用穿插熵方针函数进行预练习和微调,穿插熵方针函数是练习序列标示模型的规范方针函数。经过练习,模型能够正确地对方针词进行分类,然后从预界说的词汇表中选出输出词。
Whisper 有五种不同尺度的 checkpoint。其间,四个小尺度 checkpoint 又各有两个版别: 英语版和多语种版,而最大的 checkpoint 只有多语种版。一切九个预练习 checkpoints 都能够在 Hugging Face Hub 上找到。下表总结了这些 checkpoint 的信息及其 Hub 链接:
尺度 | 层数 | 宽 | 多头留意力的头数 | 参数量 | 英语 checkpoint | 多语种 checkpoint |
---|---|---|---|---|---|---|
tiny | 4 | 384 | 6 | 39 M | ✓ | ✓ |
base | 6 | 512 | 8 | 74 M | ✓ | ✓ |
small | 12 | 768 | 12 | 244 M | ✓ | ✓ |
medium | 24 | 1024 | 16 | 769 M | ✓ | ✓ |
large | 32 | 1280 | 20 | 1550 M | x | ✓ |
下面,咱们将以多语种版的 small
checkpoint (参数量 244M (~= 1GB)) 为例,带大家走一遍微调模型的全过程。咱们将运用 Common Voice 数据集里的小语种数据来练习和评价咱们的体系。经过这个比如,咱们将证明,仅需 8 小时的练习数据就能够微调出一个在该语种上体现强壮的语音辨认模型。
1{}^1 Whisper 的称号来自于 “Web-scale Supervised Pre-training for Speech Recognition (网络规模的有监督语音辨认预练习模型)” 的首字母缩写 “WSPSR”。
在 Google Colab 中微调 Whisper
预备环境
在微调 Whisper 模型时,咱们会用到几个盛行的 Python 包。咱们运用 datasets
来下载和预备练习数据,运用 transformers
来加载和练习 Whisper 模型。别的,咱们还需求 soundfile
包来预处理音频文件,evaluate
和 jiwer
来评价模型的性能。终究,咱们用 gradio
来为微调后的模型构建一个亮晶晶的演示运用。
!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install evaluate>=0.30
!pip install jiwer
!pip install gradio
咱们强烈建议你直接将训得的模型 checkpoint 上传到 Hugging Face Hub。Hub 供给了以下功用:
- 集成版别控制: 确保在练习期间不会丢失任何模型 checkpoint。
- Tensorboard 日志: 跟踪练习过程中的重要目标。
- 模型卡: 记载模型的用法及其运用场景。
- 社区: 轻松与社区进行分享和协作!
将 Python notebook 连上 Hub 十分简单 – 只需依据提示输入你的 Hub 身份验证令牌即可。你能够在 此处 找到你自己的 Hub 身份验证令牌:
from huggingface_hub import notebook_login
notebook_login()
打印输出:
Login successful
Your token has been saved to /root/.huggingface/token
加载数据集
Common Voice 由一系列众包数据集组成,其间包含了用各种言语录制的维基百科文本。本文运用的是最新版别的 Common Voice 数据集 (版别号为 11)。语种上,咱们挑选用 印地语 来微调咱们的模型。印地语是一种在印度北部、中部、东部和西部运用的印度 – 雅利安语。Common Voice 11.0 中有大约 12 小时的标示印地语数据,其间 4 小时是测验数据。
咱们先看下 Hub 上的 Common Voice 数据集页面: mozilla-foundation/common_voice_11_0。假如你是首次检查此页面,体系会要求你承受其运用条款,赞同后就能够拜访数据集了。
一旦身份验证成功,你就会看到数据集预览。数据集预览展现了数据集的前 100 个样本。更重要的是,它还加载了可供实时收听的音频。咱们能够在下拉菜单挑选 hi
来挑选 Common Voice 的印地语子集 ( hi
是印地语的言语标识符代码):
点击第一个音频的播放按钮,你就能够收听音频并看到相应的文本了。你还能够翻滚阅读练习集和测验会集的样本,以更好地了解待处理音频和文本数据。从语调和风格能够看出,这些音频是旁白录音。你或许还会留意到录音者和录音质量的巨大差异,这是众包数据的一个共同特征。
运用 Datasets 来下载和预备数据十分简单。仅需一行代码即可完结 Common Voice 数据集的下载和预备作业。由于印地语数据十分匮乏,咱们把 练习集
和 验证集
兼并成约 8 小时的练习数据,而测验则依据 4 小时的 测验集
:
from datasets import load_dataset, DatasetDict
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True)
print(common_voice)
打印输出:
DatasetDict({
train: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 6540
})
test: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 2894
})
})
大多数 ASR 数据集仅包含输入音频样本 ( audio
) 和相应的转录文本 ( sentence
)。 Common Voice 还包含额定的元信息,例如 accent
和 locale
,在 ASR 场景中,咱们能够疏忽这些信息。为了使代码尽或许通用,咱们只考虑依据输入音频和转录文本进行微调,而不运用额定的元信息:
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
除了 Common Voice,Hub 上还有不少其他多语种 ASR 数据集可供运用,你能够点击链接: Hub 上的 ASR 数据集 了解更多。
预备特征提取器、分词器和数据
ASR 的流水线首要包含三个模块:
- 对原始音频输入进行预处理的特征提取器
- 履行序列到序列映射的模型
- 将模型输出转化为文本的分词器
在 Transformers 中,Whisper 模型有自己的特征提取器和分词器,即 WhisperFeatureExtractor 和 WhisperTokenizer。
下面,咱们逐一详细介绍特征提取器和分词器!
加载 WhisperFeatureExtractor
语音可表明为随时间改变的一维数组,给定时间的数组值即表明信号在该时间的 起伏,而咱们能够仅从起伏信息重建音频的频谱并康复其一切声学特征。
由于语音是接连的,因而它包含无数个起伏值,而核算机只能表明并存储有限个值。因而,咱们需求经过对语音信号进行离散化,即以固定的时间距离对接连信号进行 采样。咱们将每秒采样的次数称为 采样率,通常以样本数/秒或 赫兹 (Hz) 为单位。高采样率能够更好地逼近接连语音信号,但同时每秒所需的存储量也更大。
需求特别留意的是,输入音频的采样率需求与模型期望的采样率相匹配,由于不同采样率的音频信号的散布是不同的。处理音频时,需求运用正确的采样率,不然或许会引起意想不到的成果!例如,以 16kHz 的采样率采集音频但以 8kHz 的采样率收听它,会使音频听起来好像是半速的。同样地,向一个需求某一采样率的 ASR 模型馈送一个错误采样率的音频也会影响模型的性能。Whisper 特征提取器需求采样率为 16kHz 的音频输入,因而输入的采样率要与之相匹配。咱们不想无意中用慢速语音来练习 ASR!
Whisper 特征提取器履行两个操作。首要,填充或截断一批音频样本,将一切样本的输入长度一致至 30 秒。经过在序列结束添加零 (音频信号中的零对应于无信号或静音),将短于 30 秒的样本填充到 30 秒。而对超越 30 秒的样本,直接截断为 30 秒就好了。由于这一批数据中的一切样本都被填充或截断到一致长度 (即 30 s) 了,因而将音频馈送给 Whisper 模型时就不需求留意力掩码了。这是 Whisper 的独门特性,其他大多数音频模型都需求用户供给一个留意力掩码,详细说明填充方位,这样模型才能在自留意力机制中疏忽填充部分。经过练习的 Whisper 模型能够直接从语音信号中推断出应该疏忽哪些部分,因而无需留意力掩码。
Whisper 特征提取器履行的第二个操作是将第一步所得的音频变换为对数梅尔声谱图。这些频谱图是信号频率的直观表明,类似于傅里叶变换。图 2 展现了一个声谱图的比如,其间 yy 轴表明梅尔频段 (Mel channel),对应于特定的频段,xx 轴表明时间,色彩对应于给定时间该频段的对数强度。Whisper 模型要求输入为对数梅尔声谱图。
梅尔频段是语音处理的规范办法,研究人员用它来近似表明人类的听觉范围。对于 Whisper 微调这个使命而言,咱们只需求知道声谱图是语音信号中频率的直观表明。更多有关梅尔频段的详细信息,请参阅 梅尔倒谱 一文。
幸运的是, Transformers Whisper 特征提取器仅用一行代码即可履行填充和声谱图变换两个操作!咱们运用以下代码从预练习的 checkpoint 中加载特征提取器,为音频数据处理做好预备:
from transformers import WhisperFeatureExtractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
加载 WhisperTokenizer
现在咱们加载 Whisper 分词器。Whisper 模型会输出词元,这些词元表明猜测文本在词典中的索引。分词器担任将这一系列词元映射为终究的文本字符串 (例如 [1169, 3797, 3332] -> “the cat sat”)。
曩昔,当运用编码器模型进行 ASR 时,咱们需运用 连接时序分类法 (Connectionist Temporal Classification,CTC) 进行解码。在运用 CTC 进行解码时,咱们需求为每个数据集练习一个 CTC 分词器。但运用编码器 – 解码器架构的一个优势是咱们能够直接运用预练习模型的分词器。
Whisper 分词器在 96 种语种数据上预练习而得,因而,其 字节对 (byte-pair) 覆盖面很广,简直包含了一切语种。就印地语而言,咱们能够加载分词器并将其直接用于微调。仅需指定一下方针语种和使命,分词器就会依据这些参数将语种和使命符号添加为输出序列的前缀:
from transformers import WhisperTokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
咱们能够经过对 Common Voice 数据集的第一个样本进行编解码来验证分词器是否正确编码了印地语字符。在对转录文本进行编码时,分词器在序列的最初和结束添加“特别符号”,其间包含文本的开端/结束、语种符号和使命符号 (由上一步中的参数指定)。在解码时,咱们能够挑选“跳过”这些特别符号,然后保证输出是纯文本方式的:
input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
print(f"Input: {input_str}")
print(f"Decoded w/ special: {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal: {input_str == decoded_str}")
打印输出:
Input: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Decoded w/ special: <|startoftranscript|><|hi|><|transcribe|><|notimestamps|>खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई<|endoftext|>
Decoded w/out special: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Are equal: True
组装一个 WhisperProcessor
为了简化运用,咱们能够将特征提取器和分词器 包进 到一个 WhisperProcessor
类,该类继承自 WhisperFeatureExtractor
及 WhisperTokenizer
,可依据需求用于音频处理和模型猜测。有了它,咱们在练习期间只需求保存两个目标: processor
和 model
就好了。
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
预备数据
咱们把 Common Voice 数据集的第一个样本打印出来,看看数据长什么样:
print(common_voice["train"][0])
打印输出:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 9.6724887e-07,
1.5334779e-06, 1.0415988e-06], dtype=float32),
'sampling_rate': 48000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
能够看到,样本含有一个一维音频数组及其对应的转录文本。上文现已多次谈及采样率,以及将音频的采样率与 Whisper 模型所需的采样率 (16kHz) 相匹配的重要性。由于现在输入音频的采样率为 48kHz,所以在将其馈送给 Whisper 特征提取器之前,咱们需求将其 _下采样_至 16kHz。
咱们将运用 dataset
的 cast_column
办法将输入音频转化至所需的采样率。该办法仅指示 datasets
让其在首次加载音频时 _即时地_对数据进行重采样,因而并不会改变原音频数据:
from datasets import Audio
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
从头打印下 Common Voice 数据会集的第一个音频样本,能够看到其已被重采样:
print(common_voice["train"][0])
打印输出:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-3.4206650e-07, 3.2979898e-07, 1.0042874e-06], dtype=float32),
'sampling_rate': 16000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
酷!咱们能够看到音频已被下采样到 16kHz 了。数组里边的值也变了,现在的 1 个起伏值大致对应于之前的 3 个起伏值。
现在咱们编写一个函数来为模型预备数据:
- 调用
batch["audio"]
加载和重采样音频数据。如上所述, Datasets 会即时履行任何必要的重采样操作。 - 运用特征提取器将一维音频数组变换为对数梅尔声谱图特征。
- 运用分词器将录音文本编码为 ID。
def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch["audio"]
# compute log-Mel input features from input audio array
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# encode target text to label ids
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch
咱们能够用 dataset
的 .map
办法在一切练习样本上运用上述函数:
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)
好了!练习数据预备结束!咱们继续看看怎么运用这些数据来微调 Whisper。
留意: 目前 datasets
首要运用 torchaudio
和 [librosa
](librosa.org /doc/latest/index.html) 来进行音频加载和重采样。假如你自己定制一个数据加载/采样函数的话,你完全能够直接经过 "path"
列获取音频文件路径而不用管 "audio"
列。
练习与评价
至此,数据已预备结束,能够开端练习了。练习的大部分深重的作业都会由 Trainer 来完结。咱们要做的首要有:
- 界说数据收拾器 (data collator): 数据收拾器获取预处理后的数据并将其转化为 PyTorch 张量。
- 评价目标: 咱们运用 单词错误率 (word error rate,WER) 目标来评价模型,因而需求界说一个
compute_metrics
函数来核算它。 - 加载预练习 checkpoint: 咱们需求加载预练习 checkpoint 并正确配置它以进行练习。
- 界说练习参数: Trainer 在制订练习计划时需求用到这些参数。
微调完后,咱们需求运用测验数据对其进行评价,以验证终究模型在印地语上的语音辨认作用。
界说数据收拾器
序列到序列语音模型的数据收拾器与其他使命有所不同,由于 input_features
和 labels
的处理办法是不同的: input_features
必须由特征提取器处理,而 labels
由分词器处理。
input_features
现已填充至 30s 并转化为固定维度的对数梅尔声谱图,咱们所要做的只剩将其转化为 PyTorch 张量。咱们用特征提取器的 .pad
办法来完结这一功用,且将其入参设为 return_tensors=pt
。请留意,这儿不需求额定的填充,由于输入维度现已固定了,所以咱们只需求简单地将 input_features
转化为 PyTorch 张量就好了。
另一方面,labels
数据之前并未填充。所以,咱们首要要运用分词器的 .pad
办法将序列填充至本 batch 的最大长度。然后将填充符号替换为 -100
,这样它们就能够 不 用参与损失的核算了。然后咱们把 SOT
从序列的最初去掉,稍后练习的时候咱们再把它加回来。
咱们能够运用之前界说的 WhisperProcessor
来履行特征提取和分词操作:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
咱们初始化一下刚刚界说的数据收拾器:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
评价目标
接下来要界说评价目标。咱们将运用词错误率 (WER) 目标,它是评价 ASR 体系的“规范”目标。有关其详细信息,请参阅 WER 文档。下面,咱们从 Evaluate 中加载 WER 目标:
import evaluate
metric = evaluate.load("wer")
然后咱们只需求界说一个函数来承受模型输出并回来 WER 目标。这个名为 compute_metrics
的函数首要将 -100
替换为 label_ids
中的 pad_token_id
(以便在核算损失时将其疏忽)。然后,将猜测到的 ID 和 label_ids
解码为字符串文本。终究,核算输出文本和实在文本之间的 WER:
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
加载预练习 checkpoint
现在咱们加载预练习 Whisper small
模型的 checkpoint。同样,能够经过运用 transformers 很轻松地完结这一步!
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
原始 Whisper 模型在自回归生成开端之前强制添加了若干前缀词元 ID (forced_decoder_ids
)。这些词元 ID 首要用于在零样本 ASR 使命中标识语种和使命。由于咱们现在是对已知语种 (印地语) 和使命 (转录) 进行微调,所以咱们要将 forced_decoder_ids
设置为 None
。别的,模型还按捺了一些词元 (suppress_tokens
),这些词元的对数概率被强置为 -inf
,以保证它们永久不会被采样到。咱们会用一个空列表覆盖 suppress_tokens
,即咱们不按捺任何词元:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
界说练习参数
终究一步是界说与练习相关的一切参数,下面对其间一部分参数进行了解释:
-
output_dir
: 保存模型权重的本地目录,它也会是 Hugging Face Hub 上的模型存储库称号。 -
generation_max_length
: 评价阶段,自回归生成的最大词元数。 -
save_steps
: 练习期间,每save_steps
步保存一次中心 checkpoint 并异步上传到 Hub。 -
eval_steps
: 练习期间,每eval_steps
步对中心 checkpoint 进行一次评价。 -
report_to
: 练习日志的保存方位,支撑azure_ml
、comet_ml
、mlflow
、neptune
、tensorboard
以及wandb
这些渠道。你能够依照自己的偏好进行挑选,也能够直接运用缺省的tensorboard
保存至 Hub。
如需更多其他练习参数的详细信息,请参阅 Seq2SeqTrainingArguments 文档。
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-hi", # change to a repo name of your choice
per_device_train_batch_size=16,
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
warmup_steps=500,
max_steps=4000,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=1000,
eval_steps=1000,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=True,
)
留意: 假如不想将模型 checkpoint 上传到 Hub,你需求设置 push_to_hub=False
。
咱们能够将练习参数以及模型、数据集、数据收拾器和 compute_metrics
函数一同传给 Trainer:
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=common_voice["train"],
eval_dataset=common_voice["test"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
有了这些,就能够开端练习了!
练习
要启动练习,只需履行:
trainer.train()
练习大约需求 5-10 个小时,详细取决于你的 GPU 或 Google Colab 后端的 GPU。依据 GPU 的情况,你或许会在开端练习时遇到 CUDA 内存耗尽
错误。此时,你能够将 per_device_train_batch_size
逐次削减 2 倍,同时添加 gradient_accumulation_steps
进行补偿。
打印输出:
步数 | 练习损失 | 轮数 | 验证损失 | WER |
---|---|---|---|---|
1000 | 0.1011 | 2.44 | 0.3075 | 34.63 |
2000 | 0.0264 | 4.89 | 0.3558 | 33.13 |
3000 | 0.0025 | 7.33 | 0.4214 | 32.59 |
4000 | 0.0006 | 9.78 | 0.4519 | 32.01 |
5000 | 0.0002 | 12.22 | 0.4679 | 32.10 |
最佳 WER 是 32.0% —— 对 8 小时的练习数据来说还不错!那与其他 ASR 体系相比,这个体现到底处于什么水平?为此,咱们能够检查 hf-speech-bench
,这是一个按语种和数据集对模型分别进行 WER 排名的排行榜。
微调后的模型明显提高了 Whisper small
checkpoint 的零样本性能,也突出展现了 Whisper 强壮的迁移学习才能。
当将练习成果推送到 Hub 时,只需配置适当的关键字参数 (key-word arguments,kwargs) 就能够自动将 checkpoint 提交到排行榜。如需适配自己的数据集、语种和模型称号,仅需对下述代码作出相应的修改即可:
kwargs = {
"dataset_tags": "mozilla-foundation/common_voice_11_0",
"dataset": "Common Voice 11.0", # a 'pretty' name for the training dataset
"dataset_args": "config: hi, split: test",
"language": "hi",
"model_name": "Whisper Small Hi - Sanchit Gandhi", # a 'pretty' name for your model
"finetuned_from": "openai/whisper-small",
"tasks": "automatic-speech-recognition",
"tags": "hf-asr-leaderboard",
}
现在,只需履行 push_to_hub
指令就能够将练习成果上传到 Hub 了:
trainer.push_to_hub(**kwargs)
任何人能够用你的模型的 Hub 链接拜访它。他们还能够运用标识符 "your-username/the-name-you-picked"
加载它,例如:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
model = WhisperForConditionalGeneration.from_pretrained("sanchit-gandhi/whisper-small-hi")
processor = WhisperProcessor.from_pretrained("sanchit-gandhi/whisper-small-hi")
虽然微调后的模型在 Common Voice Hindi 测验数据上的作用还不错,但其作用远算不上最优。本文的目的仅为演示怎么在恣意多语种 ASR 数据集上微调预练习的 Whisper checkpoint,对作用并未做太多深究。如需提升作用,你还能够测验更多技巧,如优化练习超参 (例如 learning rate 和 dropout) 、运用更大的预练习 checkpoint ( medium
或 large
) 等。
构建演示运用
现在模型现已微调结束,咱们开端构建一个演示运用来展现其 ASR 功用!咱们将运用 Transformers pipeline
来完结整个 ASR 流水线: 从对音频输入进行预处理一直到对模型输出进行解码。咱们运用 Gradio 来构建咱们的交互式演示。 Gradio 供给了最直截了当的构建机器学习演示运用的办法,咱们能够用它在几分钟内构建一个演示运用!
运转以下代码会生成一个 Gradio 演示运用,它用核算机的麦克风录制语音并将其馈送给微调后的 Whisper 模型以转录出相应的文本:
from transformers import pipeline
import gradio as gr
pipe = pipeline(model="sanchit-gandhi/whisper-small-hi") # change to "your-username/the-name-you-picked"
def transcribe(audio):
text = pipe(audio)["text"]
return text
iface = gr.Interface(
fn=transcribe,
inputs=gr.Audio(source="microphone", type="filepath"),
outputs="text",
title="Whisper Small Hindi",
description="Realtime demo for Hindi speech recognition using a fine-tuned Whisper small model.",
)
iface.launch()
结束语
经过本文,咱们介绍了怎么运用 Datasets、Transformers 和 Hugging Face Hub 一步步为多语种 ASR 微调一个 Whisper 模型。假如你想自己测验微调一个,请参阅 Google Colab。假如你有爱好针对英语和多语种 ASR 微调一个其它的 Transformers 模型,请必须参阅下 examples/pytorch/speech-recognition。
英文原文: hf.co/blog/fine-t…
原文作者: Sanchit Gandhi
译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,作业方向为 transformer-family 模型在各模态数据上的运用及大规模模型的练习推理。
审校/排版: zhongdongy (阿东)