Training and Fine-tuning Large Models with RLHF, Building Your Own GPT-4 (Part 3): Reinforcement Learning from Human Feedback (RLHF)

Fine-tuning a large model mainly involves the following stages:

  • Supervised fine-tuning (SFT)
  • Reward / preference modeling (RM)
  • Reinforcement learning from human feedback (RLHF)

The related code is available on GitHub: github.com/night-is-yo…

The accompanying code covers four model families:

  1. baichuan
  2. chatglm3
  3. qwen
  4. yi

This article focuses on the third stage, RLHF.

The official RLHF training example: github.com/huggingface…

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if epoch >= config.total_ppo_epochs:
        break
    question_tensors = batch["input_ids"]
    # let the actor model generate a response for each prompt in the batch
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
    # Compute reward score (using the sentiment analysis pipeline)
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
    if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
        ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")
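The snippet above assumes a few objects defined earlier in the official example (generation_kwargs, sent_kwargs, output_length_sampler). A rough sketch of what they typically look like, with illustrative values rather than the exact settings from the example:

from trl.core import LengthSampler

# sampling settings passed to ppo_trainer.generate (values are illustrative)
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
}
# settings for the reward (sentiment-analysis) pipeline:
# top_k=None returns the scores of all labels so that output[0]["score"] works
# (older examples use return_all_scores=True); function_to_apply="none" keeps the raw score
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16, "truncation": True}
# samples a response length of roughly 32 to 128 new tokens per prompt
output_length_sampler = LengthSampler(32, 128)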

A walkthrough of the PPOTrainer source code

Preface

RLHF involves four models:

  • Actor Model: the actor, i.e. the target language model we actually want to train
  • Critic Model: the critic, which estimates the total (long-term) return
  • Reward Model: the reward model, which computes the immediate reward
  • Reference Model: the reference model, which adds a "constraint" on the language model during the RLHF stage so that it does not drift in an uncontrolled direction and get worse

Among them:

  • The Actor and Critic models are trained during the RLHF stage, while the Reward and Reference models are kept frozen.
  • The Critic, Reward, and Reference models together form a "reward-and-loss" computation system (the author's own informal name, for ease of understanding); their outputs are combined to compute the loss used to update the Actor and Critic models. A minimal sketch of how these four roles map onto TRL objects follows this list.
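In TRL these four roles do not require four separate networks. A minimal sketch of the mapping, assuming the model, sentiment_pipe, tokenizer, and dataset objects from the script in this article:

# Actor + Critic: one AutoModelForCausalLMWithValueHead (lm_head = actor, v_head = critic)
# Reward:         a separate scoring model, here wrapped in a sentiment-analysis pipeline
# Reference:      the same actor weights with the LoRA adapters temporarily disabled
ppo_trainer = PPOTrainer(
    config,
    model,                  # trainable actor + critic
    ref_model=None,         # None + PEFT model -> reference = actor with adapters disabled
    tokenizer=tokenizer,
    dataset=dataset,
)
reward_model = sentiment_pipe   # frozen, returns a scalar score for each query + response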

First, look at how PPOTrainer is initialized.

# initialize the actor + critic: a causal LM with an extra value head, loaded in 8-bit with a LoRA config
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    load_in_8bit=True,
    device_map={"": current_device},
    peft_config=lora_config,
)
# the reward model is served through a sentiment-analysis pipeline
sentiment_pipe = pipeline(
    "sentiment-analysis",
    model=reward_model_name,
    device_map={"": current_device},
    model_kwargs={"load_in_8bit": True},
    tokenizer=tokenizer,
    return_token_type_ids=False,
)
# inside PPOTrainer.__init__ (other setup omitted): when the policy is a PEFT model,
# the reference model is obtained by temporarily disabling its adapters
self.optional_peft_ctx = (
    self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
    if self.is_peft_model
    else nullcontext
)

sentiment_pipe serves as the Reward Model.

self.optional_peft_ctx supplies the Reference Model: since the actor carries LoRA adapters, temporarily disabling them recovers the frozen base model.
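Conceptually, when the actor is a LoRA/PEFT model the reference model is just the actor with its adapters switched off. A minimal sketch of the idea (not the exact PPOTrainer code; model and inputs are assumed to be the LoRA-wrapped policy and a tokenized batch):

logits_actor, _, values = model(**inputs)       # adapters active: actor logits + critic values

with model.pretrained_model.disable_adapter():  # adapters off: behaves like the frozen base model
    logits_ref, _, _ = model(**inputs)          # exactly what the Reference Model is supposed to provide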

The policy is initialized as an AutoModelForCausalLMWithValueHead; the model's code is shown below.

class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    transformers_parent_class = AutoModelForCausalLM
    lm_head_namings = ["lm_head", "embed_out"]
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )
    def __init__(self, pretrained_model, **kwargs):
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head, please use a model that has one.")
        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
        self._init_weights(**v_head_kwargs)
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        kwargs["output_hidden_states"] = True  # this had already been set in the LORA / PEFT examples
        kwargs["past_key_values"] = past_key_values
        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")
        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        last_hidden_state = base_model_output.hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
        value = self.v_head(last_hidden_state).squeeze(-1)
        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()
        return (lm_logits, loss, value)

AutoModelForCausalLMWithValueHead splits the large model into two parts:

  1. lm_head
  2. v_head

Here lm_head plays the role of the Actor Model and v_head plays the role of the Critic Model.

At this point all four models are initialized. The Actor, Critic, and Reference are all based on the same underlying model, so they do not occupy GPU memory more than once.
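The forward() return value makes the split concrete: lm_logits is the actor's distribution over the vocabulary, and value is the critic's per-token value estimate. A quick shape check (gpt2 is used only to keep the example small; the article's models behave the same way):

import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

tok = AutoTokenizer.from_pretrained("gpt2")
m = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
enc = tok(["hello world"], return_tensors="pt")

lm_logits, _, value = m(**enc)
print(lm_logits.shape)  # (batch, seq_len, vocab_size) -> actor head, e.g. torch.Size([1, 2, 50257])
print(value.shape)      # (batch, seq_len)             -> critic (value) head, e.g. torch.Size([1, 2])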

The core RLHF flow:

To make the RLHF procedure easier to follow, parts of the code below have been trimmed; the heart of the whole method is step().

@PPODecorators.empty_device_cache()
def step(
    self,
    queries: List[torch.LongTensor],
    responses: List[torch.LongTensor],
    scores: List[torch.FloatTensor],
    response_masks: Optional[List[torch.LongTensor]] = None,
):
    scores = torch.tensor(scores, device=self.current_device)
    model_inputs = self.prepare_model_inputs(queries, responses)
    model_inputs_names = list(model_inputs.keys())
    full_kl_penalty = self.config.kl_penalty == "full"
    with torch.no_grad():
        all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
            self.model,
            queries,
            responses,
            model_inputs,
            response_masks=response_masks,
            return_logits=full_kl_penalty,
        )
        with self.optional_peft_ctx():
            ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                self.model if self.is_peft_model else self.ref_model,
                queries,
                responses,
                model_inputs,
                return_logits=full_kl_penalty,
            )
    with torch.no_grad():
        rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
        values, advantages, returns = self.compute_advantages(values, rewards, masks)
    # upcast to float32 to avoid dataset issues
    batch_dict = {
        "queries": queries,
        "responses": responses,
        "logprobs": all_logprobs.to(torch.float32),
        "values": values.to(torch.float32),
        "masks": masks,
        "advantages": advantages,
        "returns": returns,
    }
    batch_dict.update(model_inputs)
    t = time.time()
    all_stats = []
    early_stop = False
    bs = self.config.batch_size
    for _ in range(self.config.ppo_epochs):
        for backward_batch_start in range(0, bs, self.config.backward_batch_size):
            for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
                # (shuffling of indices and slicing of batch_dict into mini_batch_dict
                #  are omitted here; the expanded loop appears in step 5 below)
                with self.accelerator.accumulate(self.model):
                    model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
                    logprobs, logits, vpreds, _ = self.batched_forward_pass(
                        self.model,
                        mini_batch_dict["queries"],
                        mini_batch_dict["responses"],
                        model_inputs,
                        return_logits=True,
                    )
                    train_stats = self.train_minibatch(
                        mini_batch_dict["logprobs"],
                        mini_batch_dict["values"],
                        logprobs,
                        logits,
                        vpreds,
                        mini_batch_dict["masks"],
                        mini_batch_dict["advantages"],
                        mini_batch_dict["returns"],
                    )
                    all_stats.append(train_stats)
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()

The program flow is as follows.

Before step() is called, the steps are:

  1. Prepare a batch of prompts
  2. Have the Actor model generate a response for each prompt
  3. Feed prompt + response to the Reward model to produce a score

We then call step() with the following arguments:

  • question_tensors
  • response_tensors
  • rewards

Inside step(), the procedure is:

  1. Compute the actor's all_logprobs, logits_or_none, values, and masks in mini-batches

    all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
    
  2. Compute the reference model's ref_logprobs and ref_logits_or_none

    ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
    
  3. Compute the per-token rewards

    rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
    
  4. Compute the advantages

    values, advantages, returns = self.compute_advantages(values, rewards, masks)
    
  5. Compute the loss and run the gradient updates

    for _ in range(self.config.ppo_epochs):
        b_inds = np.random.permutation(bs)
        for backward_batch_start in range(0, bs, self.config.backward_batch_size):
            for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
                for k in model_inputs_names:
                    mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
                with self.accelerator.accumulate(self.model):
                    model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
                    logprobs, logits, vpreds, _ = self.batched_forward_pass(
                        self.model,
                        mini_batch_dict["queries"],
                        mini_batch_dict["responses"],
                        model_inputs,
                        return_logits=True,
                    )
                    train_stats = self.train_minibatch(
                        mini_batch_dict["logprobs"],
                        mini_batch_dict["values"],
                        logprobs,
                        logits,
                        vpreds,
                        mini_batch_dict["masks"],
                        mini_batch_dict["advantages"],
                        mini_batch_dict["returns"],
                    )
    

    The code above contains three nested for loops:

    1. The outer loop iterates ppo_epochs times over the same batch
    2. The middle loop walks through the batch in chunks of backward_batch_size (one optimizer step per chunk)
    3. The inner loop walks through each backward batch in chunks of mini_batch_size (one forward/backward pass per chunk)

    Splitting the work across the three loops keeps peak GPU memory as low as possible (see the numeric sketch right after this list).

    The same batch can be reused for self.config.ppo_epochs updates because mini_batch_dict["logprobs"], mini_batch_dict["values"], mini_batch_dict["masks"], and
    mini_batch_dict["advantages"] were computed once up front and are reused, which speeds up training.
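As a rough numeric illustration of how the three loop sizes relate (the numbers below are made up; in TRL, backward_batch_size is derived from mini_batch_size * gradient_accumulation_steps):

batch_size = 64                  # one PPO batch collected from the dataloader
mini_batch_size = 4              # what fits through a single forward/backward pass
gradient_accumulation_steps = 4
backward_batch_size = mini_batch_size * gradient_accumulation_steps   # 16 samples per optimizer step
ppo_epochs = 4                   # how many times the same batch is reused

forward_backward_passes = ppo_epochs * (batch_size // mini_batch_size)   # 64
optimizer_steps = ppo_epochs * (batch_size // backward_batch_size)       # 16
print(forward_backward_passes, optimizer_steps)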

Source code for the batched forward pass

@PPODecorators.empty_device_cache()
def batched_forward_pass(
    self,
    model: PreTrainedModelWrapper,
    queries: torch.Tensor,
    responses: torch.Tensor,
    model_inputs: dict,
    return_logits: bool = False,
    response_masks: Optional[torch.Tensor] = None,
):
    """
    Calculate model outputs in multiple batches.
    Args:
        queries (`torch.LongTensor`):
            List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
        responses (`torch.LongTensor`):
            List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
        return_logits (`bool`, *optional*, defaults to `False`):
            Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
    Returns:
        (tuple):
            - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                shape (`batch_size`, `response_length`)
            - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                shape (`batch_size`, `response_length`)
            - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
    """
    bs = len(queries)
    fbs = self.config.mini_batch_size
    all_logprobs = []
    all_logits = []
    all_masks = []
    all_values = []
    model.eval()
    for i in range(math.ceil(bs / fbs)):
        input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
        query_batch = queries[i * fbs : (i + 1) * fbs]
        response_batch = responses[i * fbs : (i + 1) * fbs]
        if response_masks is not None:
            response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
        logits, _, values = model(**input_kwargs)
        if self.is_encoder_decoder:
            input_ids = input_kwargs["decoder_input_ids"]
            attention_mask = input_kwargs["decoder_attention_mask"]
        else:
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]
        logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
        masks = torch.zeros_like(attention_mask)
        masks[:, :-1] = attention_mask[:, 1:]
        for j in range(len(query_batch)):
            if self.is_encoder_decoder:
                # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                start = 1
                end = attention_mask[j, :].sum() - 1
            else:
                start = len(query_batch[j]) - 1  # logprobs starts from the second query token
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0]
                end = start + len(response_batch[j])
                if response_masks is not None:
                    response_masks_batch[j] = torch.cat(
                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                    )[1:]
            masks[j, :start] = 0
            masks[j, end:] = 0
            if response_masks is not None:
                masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
        if return_logits:
            all_logits.append(logits)
        else:
            del logits
        all_values.append(values)
        all_logprobs.append(logprobs)
        all_masks.append(masks)
    return (
        torch.cat(all_logprobs),
        torch.cat(all_logits)[:, :-1] if return_logits else None,
        torch.cat(all_values)[:, :-1],
        torch.cat(all_masks)[:, :-1],
    )

Source code for the adaptive KL-divergence controller

class AdaptiveKLController:
    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon
    def update(self, current, n_steps):
        target = self.target
        proportional_error = np.clip(current / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult
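A small usage sketch of the controller above (the constructor values are illustrative): when the measured KL is above target the coefficient grows, and when it is below target it shrinks, with the proportional error clipped to at most 20% per update.

import numpy as np   # needed by AdaptiveKLController above

kl_ctl = AdaptiveKLController(init_kl_coef=0.2, target=6.0, horizon=10000)

kl_ctl.update(current=9.0, n_steps=256)   # KL too high -> error clipped to +0.2, coefficient grows
print(round(kl_ctl.value, 4))             # ~0.201

kl_ctl.update(current=3.0, n_steps=256)   # KL too low -> error clipped to -0.2, coefficient shrinks
print(round(kl_ctl.value, 4))             # back to ~0.2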

Per-token log_probs of the model's response

def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
    """
    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
    """
    logp = F.log_softmax(logits, dim=2)
    if not gather:
        return logp
    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logpy
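A toy example of what logprobs_from_logits returns: for every position, the log-probability that the model assigned to the token that was actually generated (the numbers are made up).

import torch

logits = torch.tensor([[[2.0, 0.5, -1.0],    # position 0, vocabulary of size 3
                        [0.1, 0.1, 3.0]]])   # position 1
labels = torch.tensor([[0, 2]])              # the tokens that were actually chosen

per_token = logprobs_from_logits(logits, labels)   # uses the function defined above
print(per_token.shape)                             # torch.Size([1, 2]): one log-prob per token
print(per_token)                                   # the log-prob of each chosen token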

Source code for computing rewards

def compute_rewards(
    self,
    scores: torch.FloatTensor,
    logprobs: torch.FloatTensor,
    ref_logprobs: torch.FloatTensor,
    masks: torch.LongTensor,
):
    """
    Compute per token rewards from scores and KL-penalty.
    Args:
        scores (`torch.FloatTensor`):
            Scores from the reward model, shape (`batch_size`)
        logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape (`batch_size`, `response_length`)
        ref_logprobs (`torch.FloatTensor`):
            Log probabilities of the reference model, shape (`batch_size`, `response_length`)
    Returns:
        `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
        `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
        `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
    """
    rewards, non_score_rewards, kls = [], [], []
    for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
        # compute KL penalty (from difference in logprobs)
        kl = self._kl_penalty(logprob, ref_logprob)
        kls.append(kl)
        non_score_reward = -self.kl_ctl.value * kl
        non_score_rewards.append(non_score_reward)
        reward = non_score_reward.clone()
        last_non_masked_index = mask.nonzero()[-1]
        # reward is preference model score + KL penalty
        reward[last_non_masked_index] += score
        rewards.append(reward)
    return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
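A worked toy example of the reward shaping above (numbers made up; kl_penalty assumed to be the default "kl", i.e. the plain log-prob difference): every response token is penalized by -kl_coef * KL, and the reward-model score is added only at the last non-masked token.

import torch

kl_coef = 0.2
score = 1.5                                       # scalar from the reward model
logprob     = torch.tensor([-1.0, -2.0, -0.5])    # actor log-probs over 3 response tokens
ref_logprob = torch.tensor([-1.2, -1.5, -0.5])    # reference-model log-probs
mask        = torch.tensor([1, 1, 1])

kl = logprob - ref_logprob                        # per-token KL estimate
reward = -kl_coef * kl                            # per-token KL penalty
reward[mask.nonzero()[-1]] += score               # score only on the last response token
print(reward)                                     # tensor([-0.0400,  0.1000,  1.5000])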

Source code for computing advantages

def compute_advantages(
    self,
    values: torch.FloatTensor,
    rewards: torch.FloatTensor,
    mask: torch.FloatTensor,
):
    lastgaelam = 0
    advantages_reversed = []
    gen_len = rewards.shape[-1]
    values = values * mask
    rewards = rewards * mask
    if self.config.whiten_rewards:
        rewards = masked_whiten(rewards, mask, shift_mean=False)
    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
    returns = advantages + values
    advantages = masked_whiten(advantages, mask)
    advantages = advantages.detach()
    return values, advantages, returns
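compute_advantages is standard GAE: delta_t = r_t + gamma * V_{t+1} - V_t and A_t = delta_t + gamma * lam * A_{t+1}, accumulated backwards over the response. A small numeric check (numbers made up, whitening and masking skipped):

import torch

gamma, lam = 1.0, 0.95
values  = torch.tensor([[0.1, 0.2, 0.3]])   # critic estimates per token
rewards = torch.tensor([[0.0, 0.0, 1.0]])   # reward arrives only on the last token

lastgaelam, adv_rev = 0.0, []
for t in reversed(range(3)):
    nextvalues = values[:, t + 1] if t < 2 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    adv_rev.append(lastgaelam)
advantages = torch.stack(adv_rev[::-1]).transpose(0, 1)
returns = advantages + values               # regression targets for the value head
print(advantages)                           # ≈ tensor([[0.827, 0.765, 0.700]])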

Source code for computing the loss

def loss(
    self,
    old_logprobs: torch.FloatTensor,
    values: torch.FloatTensor,
    logits: torch.FloatTensor,
    vpreds: torch.FloatTensor,
    logprobs: torch.FloatTensor,
    mask: torch.LongTensor,
    advantages: torch.FloatTensor,
    returns: torch.FloatTensor,
):
    """
    Calculate policy and value losses.
    Args:
        old_logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Values of the value head, shape (`batch_size`, `response_length`)
        rewards (`torch.FloatTensor`):
            Rewards from the reward model, shape (`batch_size`, `response_length`)
        logits (`torch.FloatTensor`):
            Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
        v_pred (`torch.FloatTensor`):
            Values of the value head, shape (`batch_size`, `response_length`)
        logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape (`batch_size`, `response_length`)
    """
    vpredclipped = clip_by_value(
        vpreds,
        values - self.config.cliprange_value,
        values + self.config.cliprange_value,
    )
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
    vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
    ratio = torch.exp(logprobs - old_logprobs)
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
    pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
    pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
    loss = pg_loss + self.config.vf_coef * vf_loss
    avg_ratio = masked_mean(ratio, mask).item()
    if avg_ratio > self.config.ratio_threshold:
        warnings.warn(
            f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch."
        )
        pg_loss = pg_loss * 0.0
        vf_loss = vf_loss * 0.0
        loss = loss * 0.0
    entropy = masked_mean(entropy_from_logits(logits), mask)
    approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
    policykl = masked_mean(old_logprobs - logprobs, mask)
    return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
    value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
    stats = dict(
        loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
        policy=dict(
            entropy=entropy.detach(),
            approxkl=approxkl.detach(),
            policykl=policykl.detach(),
            clipfrac=pg_clipfrac.detach(),
            advantages=advantages.detach(),
            advantages_mean=masked_mean(advantages, mask).detach(),
            ratio=ratio.detach(),
        ),
        returns=dict(mean=return_mean.detach(), var=return_var.detach()),
        val=dict(
            vpred=masked_mean(vpreds, mask).detach(),
            error=masked_mean((vpreds - returns) ** 2, mask).detach(),
            clipfrac=vf_clipfrac.detach(),
            mean=value_mean.detach(),
            var=value_var.detach(),
        ),
    )
    return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
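The policy part is the standard PPO clipped objective. A tiny numeric illustration of the clipping with cliprange = 0.2 (numbers made up): when the probability ratio moves far outside [0.8, 1.2], the clip caps the gain for positive advantages but keeps the full penalty for negative ones.

import torch

cliprange = 0.2
advantages   = torch.tensor([ 1.0, -1.0])
old_logprobs = torch.tensor([-1.0, -1.0])
logprobs     = torch.tensor([-0.5, -0.5])   # the new policy likes both tokens much more

ratio = torch.exp(logprobs - old_logprobs)  # ≈ 1.65 for both tokens
pg_losses  = -advantages * ratio                                             # unclipped
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange, 1 + cliprange)  # clipped
pg_loss = torch.max(pg_losses, pg_losses2).mean()
print(pg_loss)   # ≈ 0.22: the positive-advantage term is capped at -1.2, the negative one keeps the full 1.65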

Source code for the gradient update

@PPODecorators.empty_device_cache()
def train_minibatch(
    self,
    old_logprobs: torch.FloatTensor,
    values: torch.FloatTensor,
    logprobs: torch.FloatTensor,
    logits: torch.FloatTensor,
    vpreds: torch.FloatTensor,
    mask: torch.LongTensor,
    advantages: torch.FloatTensor,
    returns: torch.FloatTensor,
):
    """
    Train one PPO minibatch
    Args:
        logprobs (`torch.FloatTensor`):
            Log probabilities of the model, shape [mini_batch_size, response_length]
        values (`torch.FloatTensor`):
            Values of the value head, shape [mini_batch_size, response_length]
        query (`torch.LongTensor`):
            Encoded queries, shape [mini_batch_size, query_length]
        response (`torch.LongTensor`):
            Encoded responses, shape [mini_batch_size, response_length]
        model_input (`torch.LongTensor`):
            Concatenated queries and responses, shape [mini_batch_size, query_length+response_length]
    Returns:
        train_stats (dict[str, `torch.Tensor`]):
            Dictionary of training statistics
    """
    self.model.train()
    loss_p, loss_v, train_stats = self.loss(
        old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
    )
    loss = loss_p + loss_v
    self.accelerator.backward(loss)
    if self.config.max_grad_norm is not None:
        if self.accelerator.sync_gradients:
            self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
    self.optimizer.step()
    # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
    # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
    self.optimizer.zero_grad()
    return train_stats
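For reference, the config fields used throughout these snippets (ppo_epochs, mini_batch_size, gamma, lam, cliprange, vf_coef, the KL controller settings, and so on) all live on TRL's PPOConfig. A hedged sketch with illustrative values; field names may shift between TRL versions, so check the version you have installed:

from trl import PPOConfig

config = PPOConfig(
    model_name="path-to-your-sft-model",           # placeholder
    learning_rate=1.4e-5,
    batch_size=64,
    mini_batch_size=4,
    gradient_accumulation_steps=4,                 # backward_batch_size = mini_batch_size * this
    ppo_epochs=4,                                  # how often each batch is reused
    init_kl_coef=0.2, target=6.0, horizon=10000,   # AdaptiveKLController
    gamma=1.0, lam=0.95,                           # compute_advantages (GAE)
    cliprange=0.2, cliprange_value=0.2, vf_coef=0.1,   # loss
)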