Retrieval-Augmented Reinforcement Learning

Modern deep reinforcement learning (RL) algorithms distill past experience into parametric behavior rules, such as policies or value functions, via gradient updates. While effective, this approach has several disadvantages: it is computationally expensive, it takes many updates to integrate experiences into the parametric model, experiences that are not fully integrated do not appropriately influence the agent's behavior, and the resulting behavior is limited by the capacity of the model. In this paper, we propose an alternative paradigm in which the mapping from experience to behavior is amortized into a network that is trained to map a dataset of past experience to optimal behavior. Concretely, we augment an RL agent with a retrieval process (parameterized as a neural network) that has direct access to a dataset of experiences. This dataset could come from the agent's past experiences, expert demonstrations, or any other relevant source. The retrieval process is trained to retrieve information from the dataset that may be useful in the current context, helping the agent achieve its goal faster and more efficiently. We integrate our method into two different RL agents: an offline DQN agent and an online R2D2 agent. In offline multi-task problems, we show that a retrieval-augmented DQN agent avoids task interference and learns faster than a baseline DQN agent. On Atari, we show that retrieval-augmented R2D2 learns significantly faster than the baseline R2D2 agent and achieves higher scores. Ablations demonstrate that the agent uses the available past experience during learning to make more effective updates.
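
To make the idea of conditioning an agent on retrieved experience concrete, the following is a minimal sketch and not the method described above. It assumes a fixed dot-product similarity in place of the learned retrieval network, a toy in-memory dataset of key/value embeddings, and a linear Q-head; the names `retrieve`, `q_values`, `w_state`, and `w_retrieved` are hypothetical and chosen only for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def retrieve(query, keys, values, top_k=2):
    """Attend over the top_k stored entries most similar to the query.

    Assumption: similarity is a plain dot product. In the paper the
    retrieval process is a trained neural network, which this is not.
    """
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:top_k]
    weights = softmax([scores[i] for i in top])
    dim = len(values[0])
    # Weighted average of the selected value vectors.
    return [sum(w * values[i][d] for w, i in zip(weights, top)) for d in range(dim)]


def q_values(state, retrieved, w_state, w_retrieved):
    """Hypothetical linear Q-head: one weight row per action for each input."""
    return [
        sum(ws * s for ws, s in zip(row_s, state))
        + sum(wr * r for wr, r in zip(row_r, retrieved))
        for row_s, row_r in zip(w_state, w_retrieved)
    ]


# Toy experience dataset: keys embed past contexts, values embed what was useful there.
keys = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]

state = [1.0, 0.0]
retrieved = retrieve(state, keys, values, top_k=2)
q = q_values(
    state,
    retrieved,
    w_state=[[0.5, 0.0], [0.0, 0.5]],
    w_retrieved=[[1.0, 0.0], [0.0, 1.0]],
)
greedy_action = max(range(len(q)), key=lambda a: q[a])
```

In this sketch the agent's action values depend both on the current state and on a summary of the most relevant stored experiences, so adding entries to `keys` and `values` can change behavior without any gradient update to the Q-head.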
