PWM: Policy Learning with Multi-Task World Models

ICLR 2024 Submission

Authors hidden while paper is under review.

Abstract

Reinforcement Learning (RL) has made significant strides in complex tasks but struggles in multi-task settings with different embodiments. World model methods offer scalability by learning a simulation of the environment, but they often rely on inefficient gradient-free optimization methods for policy extraction. In contrast, gradient-based methods exhibit lower variance but fail to handle discontinuities. Our work reveals that well-regularized world models can generate smoother optimization landscapes than the actual dynamics, facilitating more effective first-order optimization. We introduce Policy learning with multi-task World Models (PWM), a novel model-based RL algorithm for continuous control. The world model is first pre-trained on offline data, and policies are then extracted from it using first-order optimization in less than 10 minutes per task. PWM effectively solves tasks with up to 152 action dimensions and outperforms methods that use ground-truth dynamics. Additionally, PWM scales to an 80-task setting, achieving up to 27% higher rewards than existing baselines without relying on costly online planning.

Video

Method overview


Instead of tightly coupling world models with the RL algorithm, we propose using large-scale multi-task world models as differentiable simulators for policy learning. When well-regularized, these models enable efficient policy learning with first-order gradient optimization. This allows PWM to learn to solve 80 tasks in under 10 minutes each, without the need for expensive online planning. A rough sketch of this policy-extraction step follows.
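The sketch below is illustrative only, not the official PWM implementation: the MLP world model, network sizes, rollout horizon, and learning rate are hypothetical stand-ins, and details such as latent encoding, discounting, and value bootstrapping are omitted. It shows the core idea of optimizing a policy with first-order gradients by backpropagating predicted rewards through a frozen, pre-trained world model.

```python
import torch
import torch.nn as nn

obs_dim, act_dim, horizon = 32, 8, 16

# Hypothetical pre-trained, differentiable world model: (obs, action) -> (next_obs, reward).
world_model = nn.Sequential(nn.Linear(obs_dim + act_dim, 256), nn.ELU(),
                            nn.Linear(256, obs_dim + 1))
for p in world_model.parameters():
    p.requires_grad_(False)  # the world model stays frozen during policy extraction

policy = nn.Sequential(nn.Linear(obs_dim, 256), nn.ELU(),
                       nn.Linear(256, act_dim), nn.Tanh())
optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

def policy_loss(init_obs: torch.Tensor) -> torch.Tensor:
    """Roll the policy out inside the world model and sum predicted rewards."""
    obs, total_reward = init_obs, 0.0
    for _ in range(horizon):
        action = policy(obs)
        out = world_model(torch.cat([obs, action], dim=-1))
        obs, reward = out[..., :obs_dim], out[..., obs_dim]
        total_reward = total_reward + reward.mean()
    return -total_reward  # maximize predicted return

for step in range(1000):
    loss = policy_loss(torch.randn(64, obs_dim))  # batch of imagined rollout starts
    optimizer.zero_grad()
    loss.backward()   # first-order gradients flow through the frozen world model
    optimizer.step()
```

Because the world model is differentiable end to end, the policy gradient is computed analytically through the imagined rollout rather than estimated from environment samples, which is the property PWM exploits for fast policy extraction.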
Figure: PWM teaser results.

We evaluate PWM on high-dimensional continuous control tasks (left figure) and find that it not only outperforms the model-free baselines SAC and PPO, but also achieves higher rewards than SHAC, a method that uses the dynamics and reward function of the simulator directly. In an 80-task setting (right figure) using a large 48M-parameter world model, PWM consistently outperforms TDMPC2, an MBRL method that uses the same world model but relies on online planning to select actions.

Single-task results

Figure: aggregated single-task results.

The figure shows the 50% IQM (solid lines), mean (dashed lines), and 95% CIs over all 5 tasks and 5 random seeds. PWM achieves higher reward than the model-free baselines PPO and SAC, than TDMPC2, which uses the same world model as PWM, and than SHAC, which uses the ground-truth dynamics and reward functions of the simulator. These results indicate that well-regularized world models can smooth out the optimization landscape, enabling more effective first-order gradient optimization.
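As one concrete example of what such regularization can look like: the world model PWM shares with TDMPC2 bounds its latent state with a simplicial normalization (SimNorm). The snippet below is a minimal, illustrative sketch of that idea; the tensor shapes and group size are chosen arbitrarily rather than taken from the paper.

```python
import torch
import torch.nn.functional as F

def simnorm(z: torch.Tensor, group_size: int = 8) -> torch.Tensor:
    """Simplicial normalization: softmax over fixed-size groups of latent dims."""
    batch_shape = z.shape[:-1]
    z = z.view(*batch_shape, -1, group_size)
    z = F.softmax(z, dim=-1)
    return z.view(*batch_shape, -1)

latent = simnorm(torch.randn(64, 512))  # each 8-dim group now lies on a probability simplex
```

Projecting each group of latent dimensions onto a simplex keeps the latent state bounded, which is one form of the regularization referred to above.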

Multi-task results

Figure: full multi-task results.

The figure shows the performance of PWM and TDMPC2 on the 30- and 80-task benchmarks, with results over 10 random seeds. PWM outperforms TDMPC2 while using the same world model and no form of online planning, making it the more scalable approach as world models grow. The right figure compares PWM, a single multi-task policy, with the single-task experts SAC and DreamerV3; notably, PWM matches their performance while being multi-task and trained only on offline data.

We strongly believe that increasingly large world models with billions of parameters, combined with policy learning strategies such as PWM, can unlock dexterous robotics at scale.