PWM: Policy Learning with Large World Models

NeurIPS 2024 Submission

Authors hidden while paper is under review.

Abstract

Reinforcement Learning (RL) has achieved impressive results on complex tasks but struggles in multi-task settings with different embodiments. World models offer scalability by learning a simulation of the environment, yet the methods built on them often rely on inefficient gradient-free optimization. We introduce Policy learning with large World Models (PWM), a novel model-based RL algorithm that learns continuous control policies from large multi-task world models. By pre-training the world model on offline data and using it for first-order gradient policy learning, PWM effectively solves tasks with up to 152 action dimensions and outperforms methods that use ground-truth dynamics. Additionally, PWM scales to an 80-task setting, achieving up to 27% higher rewards than existing baselines without the need for expensive online planning.

Video

Method overview


Instead of building world models into planning algorithms, we propose using large-scale multi-task world models as differentiable simulators for policy learning. When well-regularized, these models enable efficient policy learning via first-order gradient optimization. This allows PWM to learn to solve 80 tasks in under 10 minutes each without the need for expensive online planning.
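The core idea can be illustrated with a minimal sketch. The snippet below is not the released PWM code: it simply rolls a policy through a frozen, pre-trained world model and backpropagates the predicted return to obtain first-order policy gradients. The network sizes, horizon, discount, deterministic tanh policy, and the simple dynamics/reward heads are illustrative assumptions; details such as the value function and the specific world-model architecture used in the paper are omitted.

```python
# Minimal sketch of first-order gradient policy learning through a frozen world model.
# All dimensions and hyperparameters below are illustrative placeholders.
import torch
import torch.nn as nn

OBS_DIM, ACT_DIM, HORIZON, GAMMA = 32, 8, 16, 0.99

def mlp(inp, out):
    return nn.Sequential(nn.Linear(inp, 256), nn.ELU(), nn.Linear(256, out))

class WorldModel(nn.Module):
    """Stand-in world model: pre-trained on offline data, frozen during policy learning."""
    def __init__(self):
        super().__init__()
        self.dynamics = mlp(OBS_DIM + ACT_DIM, OBS_DIM)  # predicts s_{t+1} from (s_t, a_t)
        self.reward = mlp(OBS_DIM + ACT_DIM, 1)          # predicts r_t from (s_t, a_t)

    def step(self, s, a):
        sa = torch.cat([s, a], dim=-1)
        return self.dynamics(sa), self.reward(sa).squeeze(-1)

policy = mlp(OBS_DIM, ACT_DIM)
world_model = WorldModel()
for p in world_model.parameters():
    p.requires_grad_(False)  # only the policy is updated; the world model acts as a simulator

opt = torch.optim.Adam(policy.parameters(), lr=3e-4)

def policy_update(start_states):
    """Roll the policy through the world model and backprop the predicted return."""
    s, total_reward = start_states, 0.0
    for t in range(HORIZON):
        a = torch.tanh(policy(s))      # deterministic actions for clarity
        s, r = world_model.step(s, a)  # differentiable model step
        total_reward = total_reward + (GAMMA ** t) * r
    loss = -total_reward.mean()        # maximize predicted discounted return
    opt.zero_grad()
    loss.backward()                    # first-order gradients flow through the rollout
    opt.step()
    return -loss.item()

# Example update from a batch of (e.g., offline) start states.
print(policy_update(torch.randn(64, OBS_DIM)))
```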
PWM teaser results
We evaluate PWM on high-dimensional continuous control tasks (left figure) and find that it not only outperforms the model-free baselines SAC and PPO but also achieves higher rewards than SHAC, a method that directly uses the simulator's dynamics and reward function. In an 80-task setting (right figure) using a large 48M-parameter world model, PWM consistently outperforms TDMPC2, an MBRL method that uses the same world model but plans actions online.

Single-task results

Aggregate single-task results

The figure shows the 50% interquartile mean (IQM) with solid lines, the mean with dashed lines, and 95% confidence intervals over all 5 tasks and 5 random seeds. PWM achieves higher reward than the model-free baselines PPO and SAC; TDMPC2, which uses the same world model as PWM; and SHAC, which uses the ground-truth dynamics and reward functions of the simulator. These results indicate that well-regularized world models can smooth out the optimization landscape, allowing for better first-order gradient optimization.

Multi-task results

Full multi-task results

The figure shows the performance of PWM and TDMPC2 on the 30- and 80-task benchmarks, with results over 10 random seeds. PWM outperforms TDMPC2 while using the same world model and no form of online planning, making it the more scalable approach to large world models. The right figure compares PWM, a single multi-task policy, against the single-task experts SAC and DreamerV3; notably, PWM matches their performance despite being multi-task and trained only on offline data.

We strongly believe that increasingly large world models, scaled to billions of parameters, combined with policy learning methods such as PWM, can unlock dexterous robotics at scale.