Fine-tuning Large Models with RLHF and Training Your Own GPT-4 (Part 3): Reinforcement Learning from Human Feedback (RLHF)
Fine-tuning a large language model mainly involves the following stages:
- Supervised Fine-tuning (SFT)
- Reward / preference modeling (RM)
- Reinforcement Learning from Human Feedback (RLHF)
The related code is available on GitHub: github.com/night-is-yo…
The repository implements RLHF for four model families:
- baichuan
- chatglm3
- qwen
- yi
This article focuses on the third stage, RLHF.
Official RLHF training example: github.com/huggingface…
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if epoch >= config.total_ppo_epochs:
        break

    question_tensors = batch["input_ids"]

    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute reward score (using the sentiment analysis pipeline)
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

    if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
        ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")
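The loop above relies on a few objects that the example script defines earlier and that are not reproduced here: generation_kwargs, output_length_sampler and sent_kwargs. A representative setup might look like the following; the concrete values are assumptions for illustration, not a copy of the linked script.

from trl.core import LengthSampler

# Sampling settings passed to ppo_trainer.generate (illustrative values)
generation_kwargs = {
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
}

# Sample a response length between 32 and 128 tokens per generation (bounds are assumptions)
output_length_sampler = LengthSampler(32, 128)

# Reward-pipeline arguments: "none" keeps raw scores instead of softmax probabilities; top_k=None
# returns a list of label scores per input, matching the output[0]["score"] access above (assumed)
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}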
PPOTrainer Source Code Walkthrough
Overview
RLHF involves four models:
- Actor Model: the language model we actually want to train.
- Critic Model: estimates the total (expected) return.
- Reward Model: computes the immediate reward.
- Reference Model: constrains the language model during the RLHF stage so that it does not drift in an uncontrolled direction and degrade.
Among them:
- The Actor and Critic Models are trained during RLHF (in the original diagram these two are drawn with bold borders to indicate this), while the Reward and Reference Models have frozen parameters — see the small sketch after this list.
- The Critic, Reward and Reference Models together form what I call a "reward-loss" computation system (my own term, for ease of understanding); their outputs are combined to compute the loss used to update the Actor and Critic Models.
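As a rough illustration of the trainable/frozen split (a minimal sketch with hypothetical names, not the trl API):

import torch.nn as nn

def set_rlhf_trainable_flags(actor: nn.Module, critic: nn.Module,
                             reward_model: nn.Module, reference_model: nn.Module) -> None:
    """Hypothetical helper: only the Actor and Critic receive gradients during RLHF."""
    for module, trainable in [(actor, True), (critic, True),
                              (reward_model, False), (reference_model, False)]:
        for p in module.parameters():
            p.requires_grad = trainable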
First, look at how PPOTrainer is initialized.
# Initialize the model
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    load_in_8bit=True,
    device_map={"": current_device},
    peft_config=lora_config,
)

sentiment_pipe = pipeline(
    "sentiment-analysis",
    model=reward_model_name,
    device_map={"": current_device},
    model_kwargs={"load_in_8bit": True},
    tokenizer=tokenizer,
    return_token_type_ids=False,
)
def __init__(self, ...):  # excerpt from PPOTrainer.__init__
    self.optional_peft_ctx = (
        self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
        if self.is_peft_model
        else nullcontext
    )
Here, sentiment_pipe plays the role of the Reward Model, and self.optional_peft_ctx provides the Reference Model.
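When the Actor is a LoRA/PEFT model, the frozen base weights double as the Reference Model: temporarily disabling the adapter recovers the original model, so no second full copy needs to be kept in memory. A minimal sketch of the idea (illustrative usage, not the exact call site inside PPOTrainer):

# With a PEFT/LoRA-wrapped model, the same weights play two roles:
with model.pretrained_model.disable_adapter():             # adapter off -> frozen Reference Model
    ref_logits = model.pretrained_model(input_ids).logits
actor_logits = model.pretrained_model(input_ids).logits    # adapter on  -> trainable Actor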
Model initialization uses AutoModelForCausalLMWithValueHead; its code is shown below.
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    transformers_parent_class = AutoModelForCausalLM
    lm_head_namings = ["lm_head", "embed_out"]
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )

    def __init__(self, pretrained_model, **kwargs):
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)

        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
        self._init_weights(**v_head_kwargs)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        kwargs["output_hidden_states"] = True  # this had already been set in the LORA / PEFT examples
        kwargs["past_key_values"] = past_key_values

        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")

        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        last_hidden_state = base_model_output.hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss

        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        value = self.v_head(last_hidden_state).squeeze(-1)

        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        return (lm_logits, loss, value)
AutoModelForCausalLMWithValueHead splits the large model into two heads:
- lm_head
- v_head
Here lm_head plays the role of the Actor Model and v_head plays the role of the Critic Model.
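The v_head is only a small scalar head on top of the decoder's last hidden states. A simplified sketch of what ValueHead roughly does (inferred from how it is used in forward above; not the exact trl implementation):

import torch.nn as nn

class ValueHeadSketch(nn.Module):
    """Maps each token's hidden state to one scalar value estimate."""

    def __init__(self, hidden_size: int, summary_dropout_prob: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(summary_dropout_prob)
        self.summary = nn.Linear(hidden_size, 1)   # one value per token

    def forward(self, hidden_states):                       # (batch, seq_len, hidden_size)
        return self.summary(self.dropout(hidden_states))    # (batch, seq_len, 1)

The .squeeze(-1) in forward then produces the (batch, seq_len) value estimates used by the Critic.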
At this point all four models are initialized. The Actor, Critic and Reference are all backed by the same base model, so they do not occupy GPU memory multiple times.
The core RLHF loop:
To make the RLHF procedure easier to follow, some code has been trimmed here; the heart of the whole method is the step() method.
@PPODecorators.empty_device_cache()
def step(
    self,
    queries: List[torch.LongTensor],
    responses: List[torch.LongTensor],
    scores: List[torch.FloatTensor],
    response_masks: Optional[List[torch.LongTensor]] = None,
):
    scores = torch.tensor(scores, device=self.current_device)

    model_inputs = self.prepare_model_inputs(queries, responses)
    model_inputs_names = list(model_inputs.keys())

    full_kl_penalty = self.config.kl_penalty == "full"

    with torch.no_grad():
        all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
            self.model,
            queries,
            responses,
            model_inputs,
            response_masks=response_masks,
            return_logits=full_kl_penalty,
        )
        with self.optional_peft_ctx():
            ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                self.model if self.is_peft_model else self.ref_model,
                queries,
                responses,
                model_inputs,
                return_logits=full_kl_penalty,
            )

    with torch.no_grad():
        rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
        values, advantages, returns = self.compute_advantages(values, rewards, masks)

    # upcast to float32 to avoid dataset issues
    batch_dict = {
        "queries": queries,
        "responses": responses,
        "logprobs": all_logprobs.to(torch.float32),
        "values": values.to(torch.float32),
        "masks": masks,
        "advantages": advantages,
        "returns": returns,
    }
    batch_dict.update(model_inputs)

    t = time.time()
    all_stats = []
    early_stop = False
    for _ in range(self.config.ppo_epochs):
        # (trimmed: bs and mini_batch_dict are built from the config and batch_dict
        #  with shuffled mini-batch indices -- see the fuller excerpt below)
        for backward_batch_start in range(0, bs, self.config.backward_batch_size):
            for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
                with self.accelerator.accumulate(self.model):
                    model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
                    logprobs, logits, vpreds, _ = self.batched_forward_pass(
                        self.model,
                        mini_batch_dict["queries"],
                        mini_batch_dict["responses"],
                        model_inputs,
                        return_logits=True,
                    )
                    train_stats = self.train_minibatch(
                        mini_batch_dict["logprobs"],
                        mini_batch_dict["values"],
                        logprobs,
                        logits,
                        vpreds,
                        mini_batch_dict["masks"],
                        mini_batch_dict["advantages"],
                        mini_batch_dict["returns"],
                    )
                    all_stats.append(train_stats)

    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
The program flow is as follows.
Before step() is called, the preparation is:
- Prepare a batch of prompts.
- Let the Actor model generate the responses.
- Feed the prompt + response pairs to the Reward model to produce the scores.
Then step() is called with the following arguments:
- question_tensors
- response_tensors
- rewards
Inside step(), the method proceeds through the following stages:
- Batch-compute the policy model's all_logprobs, logits_or_none, values, masks:
  all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(...)
- Compute the reference model's ref_logprobs, ref_logits_or_none:
  ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(...)
- Compute the rewards:
  rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
- Compute the advantages:
  values, advantages, returns = self.compute_advantages(values, rewards, masks)
- Compute the loss and update the gradients:
for _ in range(self.config.ppo_epochs):
    b_inds = np.random.permutation(bs)
    for backward_batch_start in range(0, bs, self.config.backward_batch_size):
        for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
            for k in model_inputs_names:
                mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
            with self.accelerator.accumulate(self.model):
                model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
                logprobs, logits, vpreds, _ = self.batched_forward_pass(
                    self.model,
                    mini_batch_dict["queries"],
                    mini_batch_dict["responses"],
                    model_inputs,
                    return_logits=True,
                )
                train_stats = self.train_minibatch(
                    mini_batch_dict["logprobs"],
                    mini_batch_dict["values"],
                    logprobs,
                    logits,
                    vpreds,
                    mini_batch_dict["masks"],
                    mini_batch_dict["advantages"],
                    mini_batch_dict["returns"],
                )
The code above contains three nested for loops:
- The first loop runs ppo_epochs optimization passes over the same rollout data.
- The second loop splits the batch into backward batches of size backward_batch_size.
- The third loop splits each backward batch into mini-batches of size mini_batch_size.
By processing only one mini-batch at a time (with gradient accumulation handled by accelerator.accumulate), these three loops keep GPU memory usage to a minimum.
Each rollout batch is reused for self.config.ppo_epochs updates because mini_batch_dict["logprobs"], mini_batch_dict["values"], mini_batch_dict["masks"] and mini_batch_dict["advantages"] are computed once per rollout and then reused, which speeds up training.
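To make the loop sizes concrete (illustrative numbers; in trl the backward batch size is derived as mini_batch_size * gradient_accumulation_steps):

# Illustrative configuration (assumed, not taken from the article's scripts):
batch_size = 128                      # prompts collected per rollout
mini_batch_size = 8                   # sequences per forward/backward pass
gradient_accumulation_steps = 4
backward_batch_size = mini_batch_size * gradient_accumulation_steps   # 32

# Per rollout, the loops then run:
#   outer : ppo_epochs passes over the same cached rollout data
#   middle: batch_size // backward_batch_size = 4 backward batches
#   inner : backward_batch_size // mini_batch_size = 4 mini-batches each,
#           with accelerator.accumulate() deferring the effective optimizer update
#           so gradients are applied once per backward batch.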
Source code of the batched forward pass (batched_forward_pass)
@PPODecorators.empty_device_cache()
def batched_forward_pass(
    self,
    model: PreTrainedModelWrapper,
    queries: torch.Tensor,
    responses: torch.Tensor,
    model_inputs: dict,
    return_logits: bool = False,
    response_masks: Optional[torch.Tensor] = None,
):
    """
    Calculate model outputs in multiple batches.

    Args:
        queries (`torch.LongTensor`):
            List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
        responses (`torch.LongTensor`):
            List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
        return_logits (`bool`, *optional*, defaults to `False`):
            Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
    Returns:
        (tuple):
            - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
              shape (`batch_size`, `response_length`)
            - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
              shape (`batch_size`, `response_length`)
            - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
    """
    bs = len(queries)
    fbs = self.config.mini_batch_size
    all_logprobs = []
    all_logits = []
    all_masks = []
    all_values = []

    model.eval()

    for i in range(math.ceil(bs / fbs)):
        input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
        query_batch = queries[i * fbs : (i + 1) * fbs]
        response_batch = responses[i * fbs : (i + 1) * fbs]
        if response_masks is not None:
            response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
        logits, _, values = model(**input_kwargs)

        if self.is_encoder_decoder:
            input_ids = input_kwargs["decoder_input_ids"]
            attention_mask = input_kwargs["decoder_attention_mask"]
        else:
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

        logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
        masks = torch.zeros_like(attention_mask)
        masks[:, :-1] = attention_mask[:, 1:]

        for j in range(len(query_batch)):
            if self.is_encoder_decoder:
                # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                start = 1
                end = attention_mask[j, :].sum() - 1
            else:
                start = len(query_batch[j]) - 1  # logprobs starts from the second query token
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0]
                end = start + len(response_batch[j])
                if response_masks is not None:
                    response_masks_batch[j] = torch.cat(
                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                    )[1:]

            masks[j, :start] = 0
            masks[j, end:] = 0
            if response_masks is not None:
                masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

        if return_logits:
            all_logits.append(logits)
        else:
            del logits
        all_values.append(values)
        all_logprobs.append(logprobs)
        all_masks.append(masks)

    return (
        torch.cat(all_logprobs),
        torch.cat(all_logits)[:, :-1] if return_logits else None,
        torch.cat(all_values)[:, :-1],
        torch.cat(all_masks)[:, :-1],
    )
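The start/end bookkeeping above simply marks which positions of logprobs belong to the response. A small worked example for the decoder-only case without left padding (hypothetical lengths):

# Query of 4 tokens + response of 3 tokens -> 7 input ids (indices 0..6).
# logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) gives 6 entries,
# where entry i is the log-prob of token i+1 given tokens 0..i.
#
#   start = len(query) - 1 = 3        # log-prob of the first response token
#   end   = start + len(response) = 6
#
# masks[j, :3] = 0 and masks[j, 6:] = 0, so only entries 3, 4, 5 (the three
# response tokens) contribute to the KL penalty and the PPO loss.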
Source code of the adaptive KL-coefficient controller
class AdaptiveKLController:
    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current, n_steps):
        target = self.target
        proportional_error = np.clip(current / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult
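The controller nudges the KL coefficient toward a target KL: when the measured KL exceeds the target the coefficient grows, otherwise it shrinks, and the horizon controls how quickly. A quick numeric check (values chosen for illustration):

import numpy as np  # required by AdaptiveKLController.update above

kl_ctl = AdaptiveKLController(init_kl_coef=0.2, target=6.0, horizon=10000)

# Measured KL of 9.0 over 256 samples:
#   proportional_error = clip(9.0 / 6.0 - 1, -0.2, 0.2) = 0.2
#   mult               = 1 + 0.2 * 256 / 10000         ≈ 1.00512
kl_ctl.update(current=9.0, n_steps=256)
print(kl_ctl.value)   # ≈ 0.2010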
Log-probabilities of the response tokens
def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
    """
    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
    """
    logp = F.log_softmax(logits, dim=2)

    if not gather:
        return logp
    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logpy
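A small usage example (shapes are illustrative): because the logits at position t predict the token at t + 1, the caller shifts the logits left and the labels right, exactly as batched_forward_pass does.

import torch

logits = torch.randn(2, 5, 32000)            # (batch, seq_len, vocab_size)
input_ids = torch.randint(0, 32000, (2, 5))  # (batch, seq_len)

# Log-prob the model assigns to each actually observed next token
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
print(logprobs.shape)                         # torch.Size([2, 4])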
Source code for computing rewards (compute_rewards)
def compute_rewards(
    self,
    scores: torch.FloatTensor,
    logprobs: torch.FloatTensor,
    ref_logprobs: torch.FloatTensor,
    masks: torch.LongTensor,
):
    """
    Compute per token rewards from scores and KL-penalty.

    Args:
        scores (`torch.FloatTensor`):
            Scores from the reward model, shape (`batch_size`)
        logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape (`batch_size`, `response_length`)
        ref_logprobs (`torch.FloatTensor`):
            Log probabilities of the reference model, shape (`batch_size`, `response_length`)

    Returns:
        `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
        `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
        `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
    """
    rewards, non_score_rewards, kls = [], [], []
    for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
        # compute KL penalty (from difference in logprobs)
        kl = self._kl_penalty(logprob, ref_logprob)
        kls.append(kl)
        non_score_reward = -self.kl_ctl.value * kl
        non_score_rewards.append(non_score_reward)
        reward = non_score_reward.clone()
        last_non_masked_index = mask.nonzero()[-1]

        # reward is preference model score + KL penalty
        reward[last_non_masked_index] += score
        rewards.append(reward)
    return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
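In other words, every response token receives a KL-penalty reward, and the scalar reward-model score is added only at the last non-masked token. Plugging in numbers for a 4-token response with kl_ctl.value = 0.2 and a score of 1.0 (illustrative values):

import torch

kl = torch.tensor([0.10, 0.05, 0.20, 0.15])   # per-token KL between Actor and Reference
beta, score = 0.2, 1.0

non_score_reward = -beta * kl                 # tensor([-0.02, -0.01, -0.04, -0.03])
reward = non_score_reward.clone()
reward[-1] += score                           # tensor([-0.02, -0.01, -0.04,  0.97])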
Source code for computing advantages (compute_advantages)
def compute_advantages(
    self,
    values: torch.FloatTensor,
    rewards: torch.FloatTensor,
    mask: torch.FloatTensor,
):
    lastgaelam = 0
    advantages_reversed = []
    gen_len = rewards.shape[-1]

    values = values * mask
    rewards = rewards * mask

    if self.config.whiten_rewards:
        rewards = masked_whiten(rewards, mask, shift_mean=False)

    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

    returns = advantages + values
    advantages = masked_whiten(advantages, mask)
    advantages = advantages.detach()
    return values, advantages, returns
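This is standard Generalized Advantage Estimation (GAE), computed backwards over the response tokens. Written out, with $\gamma$ = self.config.gamma and $\lambda$ = self.config.lam:

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad \hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}, \qquad R_t = \hat{A}_t + V(s_t)$$

The returns $R_t$ are the regression targets for the value head, and the advantages are whitened (zero mean, unit variance under the mask) before being used in the policy loss.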
Source code for computing the loss (loss)
def loss(
    self,
    old_logprobs: torch.FloatTensor,
    values: torch.FloatTensor,
    logits: torch.FloatTensor,
    vpreds: torch.FloatTensor,
    logprobs: torch.FloatTensor,
    mask: torch.LongTensor,
    advantages: torch.FloatTensor,
    returns: torch.FloatTensor,
):
    """
    Calculate policy and value losses.

    Args:
        old_logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Values of the value head, shape (`batch_size`, `response_length`)
        rewards (`torch.FloatTensor`):
            Rewards from the reward model, shape (`batch_size`, `response_length`)
        logits (`torch.FloatTensor`):
            Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
        v_pred (`torch.FloatTensor`):
            Values of the value head, shape (`batch_size`, `response_length`)
        logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape (`batch_size`, `response_length`)
    """
    vpredclipped = clip_by_value(
        vpreds,
        values - self.config.cliprange_value,
        values + self.config.cliprange_value,
    )

    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
    vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)

    ratio = torch.exp(logprobs - old_logprobs)

    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)

    pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
    pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)

    loss = pg_loss + self.config.vf_coef * vf_loss

    avg_ratio = masked_mean(ratio, mask).item()
    if avg_ratio > self.config.ratio_threshold:
        warnings.warn(
            f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch."
        )
        pg_loss = pg_loss * 0.0
        vf_loss = vf_loss * 0.0
        loss = loss * 0.0

    entropy = masked_mean(entropy_from_logits(logits), mask)

    approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
    policykl = masked_mean(old_logprobs - logprobs, mask)

    return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
    value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)

    stats = dict(
        loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
        policy=dict(
            entropy=entropy.detach(),
            approxkl=approxkl.detach(),
            policykl=policykl.detach(),
            clipfrac=pg_clipfrac.detach(),
            advantages=advantages.detach(),
            advantages_mean=masked_mean(advantages, mask).detach(),
            ratio=ratio.detach(),
        ),
        returns=dict(mean=return_mean.detach(), var=return_var.detach()),
        val=dict(
            vpred=masked_mean(vpreds, mask).detach(),
            error=masked_mean((vpreds - returns) ** 2, mask).detach(),
            clipfrac=vf_clipfrac.detach(),
            mean=value_mean.detach(),
            var=value_var.detach(),
        ),
    )
    return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
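In formula form, these are the standard PPO objectives, with $\epsilon$ = cliprange, $\epsilon_v$ = cliprange_value and $c_v$ = vf_coef:

$$r_t(\theta) = \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)\big)$$

$$L^{\mathrm{policy}} = \mathbb{E}_t\Big[\max\big(-\hat{A}_t\, r_t(\theta),\; -\hat{A}_t\,\mathrm{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\big)\Big]$$

$$L^{\mathrm{value}} = \tfrac{1}{2}\,\mathbb{E}_t\Big[\max\big((V_\theta(s_t) - R_t)^2,\; (\mathrm{clip}(V_\theta(s_t),\, V_{\mathrm{old}}(s_t) - \epsilon_v,\, V_{\mathrm{old}}(s_t) + \epsilon_v) - R_t)^2\big)\Big], \qquad L = L^{\mathrm{policy}} + c_v\, L^{\mathrm{value}}$$

The ratio guard at the end zeroes out the whole loss when the average ratio exceeds ratio_threshold, which effectively skips a batch whose policy has drifted too far from the cached old log-probs.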
Source code for the gradient update (train_minibatch)
@PPODecorators.empty_device_cache()
def train_minibatch(
    self,
    old_logprobs: torch.FloatTensor,
    values: torch.FloatTensor,
    logprobs: torch.FloatTensor,
    logits: torch.FloatTensor,
    vpreds: torch.FloatTensor,
    mask: torch.LongTensor,
    advantages: torch.FloatTensor,
    returns: torch.FloatTensor,
):
    """
    Train one PPO minibatch

    Args:
        logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape [mini_batch_size, response_length]
        values (`torch.FloatTensor`):
            Values of the value head, shape [mini_batch_size, response_length]
        query (`torch.LongTensor`):
            Encoded queries, shape [mini_batch_size, query_length]
        response (`torch.LongTensor`):
            Encoded responses, shape [mini_batch_size, response_length]
        model_input (`torch.LongTensor`):
            Concatenated queries and responses, shape [mini_batch_size, query_length+response_length]

    Returns:
        train_stats (dict[str, `torch.Tensor`]):
            Dictionary of training statistics
    """
    self.model.train()
    loss_p, loss_v, train_stats = self.loss(
        old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
    )
    loss = loss_p + loss_v
    self.accelerator.backward(loss)
    if self.config.max_grad_norm is not None:
        if self.accelerator.sync_gradients:
            self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
    self.optimizer.step()
    # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
    # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
    self.optimizer.zero_grad()
    return train_stats