DeepSpeed-Chat 允许用户使用我们灵活的 API(如下所示)构建自己的 RLHF 训练流程,用户可以使用这些 API 重建自己的 RLHF 训练策略。我们希望这些功能可以为研究探索中创建各种 RLHF 算法提供通用接口和后端。
engine = DeepSpeedRLHFEngine(
actor_model_name_or_path=args.actor_model_name_or_path,
critic_model_name_or_path=args.critic_model_name_or_path,
tokenizer=tokenizer,
num_total_iters=num_total_iters,
args=args)
trainer = DeepSpeedPPOTrainer(engine=engine, args=args)
for prompt_batch in prompt_train_dataloader:
out = trainer.generate_experience(prompt_batch)
actor_loss, critic_loss = trainer.train_rlhf(out)
© 版权声明
文章版权归作者所有,未经允许请勿转载。