文 / David Budden 与 Matteo Hessel
DeepMind 工程师通过构建工具、对算法进行拓展和创造具有挑战性的虚拟和物理环境来训练和测试人工智能 (AI) 系统,加速我们的研究。作为这项工作的一部分,我们在持续评估机器学习新的库和框架。
近来,我们发现由 Google Research 团队开发的机器学习框架 JAX 为越来越多的项目提供良好支持。JAX 与我们的工程理念产生了很好的共鸣,并在去年被我们的研究社区广泛使用。本文将分享我们的 JAX 使用经验,来说明我们认为它有助于我们 AI 研究的原因,并概述我们正在为支持各地研究人员而建立的生态系统。
Google Research
https://research.google/JAX
https://github.com/google/jax#jax-autograd-and-xla-
为什么选择 JAX?
JAX 是为高性能数字计算(尤其是机器学习研究)而设计的 Python 库。其用于数值计算的 API 基于 NumPy 这样一个用于科学计算的函数库所构建。得益于 Python 和 NumPy 较高的使用率和知名度,使得 JAX 简洁灵活、易于使用。
NumPy
https://www.nature.com/articles/s41586-020-2649-2
除了其 NumPy API 之外,JAX 还具有一个用于可组合函数的转换的扩展系统,在以下几方面帮助机器学习研究:
微分:梯度优化是 ML 的基础。通过 grad、hessian、jacfwd 和 jacrev 等方法实现了函数转换,JAX 为任意数值函数的正向和反向自动微分提供了原生支持。
自动微分
https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
向量化:在 ML 研究中,我们经常将一个函数应用于大量数据中,例如计算一个批次数据的损失,或在微分独立学习时评估每个样本的梯度。JAX 通过 vmap 转换实现自动向量化,简化了这种形式的编程。又例如,研究人员在实现新算法时,无需推理批处理。JAX 还提供相关 pmap 转换来支持大规模数据并行,在数据过大时精妙地分配单个加速器内存。
评估每个样本的梯度
https://arxiv.org/abs/2010.09063
JIT 编译:XLA 被用于在 GPU 和 Cloud TPU 加速器上进行及时 (JIT) 编译和执行 JAX 程序。JIT 编译结合 JAX 中与 NumPy 一致的 API,使没有高性能计算经验的研究人员也可以轻松扩展研究至一个或多个加速器上。
XLA
https://tensorflow.google.cn/xla