不断发展的 JAX:加速 AI 研究的利器

JAX是一个由Google Research开发的用于高性能数字计算,特别是机器学习的Python库。它基于NumPy构建,提供自动微分、向量化和JIT编译等功能,助力DeepMind的AI研究。DeepMind已创建一系列基于JAX的库,如Haiku(神经网络)、Optax(优化)、RLax(强化学习)和Chex(测试工具),以支持其研究社区的快速发展。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

文 / David Budden 与 Matteo Hessel

DeepMind 工程师通过构建工具、对算法进行拓展和创造具有挑战性的虚拟和物理环境来训练和测试人工智能 (AI) 系统,加速我们的研究。作为这项工作的一部分,我们在持续评估机器学习新的库和框架。

近来,我们发现由 Google Research 团队开发的机器学习框架 JAX 为越来越多的项目提供良好支持。JAX 与我们的工程理念产生了很好的共鸣,并在去年被我们的研究社区广泛使用。本文将分享我们的 JAX 使用经验,来说明我们认为它有助于我们 AI 研究的原因,并概述我们正在为支持各地研究人员而建立的生态系统。

  • Google Research
    https://research.google/

  • JAX
    https://github.com/google/jax#jax-autograd-and-xla-

为什么选择 JAX?

JAX 是为高性能数字计算(尤其是机器学习研究)而设计的 Python 库。其用于数值计算的 API 基于 NumPy 这样一个用于科学计算的函数库所构建。得益于 Python 和 NumPy 较高的使用率和知名度,使得 JAX 简洁灵活、易于使用。

  • NumPy
    https://www.nature.com/articles/s41586-020-2649-2

除了其 NumPy API 之外,JAX 还具有一个用于可组合函数的转换的扩展系统,在以下几方面帮助机器学习研究:

  • 微分:梯度优化是 ML 的基础。通过 grad、hessian、jacfwd 和 jacrev 等方法实现了函数转换,JAX 为任意数值函数的正向和反向自动微分提供了原生支持。

  • 自动微分
    https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

  • 向量化:在 ML 研究中,我们经常将一个函数应用于大量数据中,例如计算一个批次数据的损失,或在微分独立学习时评估每个样本的梯度。JAX 通过 vmap 转换实现自动向量化,简化了这种形式的编程。又例如,研究人员在实现新算法时,无需推理批处理。JAX 还提供相关 pmap 转换来支持大规模数据并行,在数据过大时精妙地分配单个加速器内存。

  • 评估每个样本的梯度
    https://arxiv.org/abs/2010.09063

  • JIT 编译:XLA 被用于在 GPU 和 Cloud TPU 加速器上进行及时 (JIT) 编译和执行 JAX 程序。JIT 编译结合 JAX 中与 NumPy 一致的 API,使没有高性能计算经验的研究人员也可以轻松扩展研究至一个或多个加速器上。

  • XLA
    https://tensorflow.google.cn/xla

登录后您可以享受以下权益:

×
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值

举报

选择你想要举报的内容(必选)
  • 内容涉黄
  • 政治相关
  • 内容抄袭
  • 涉嫌广告
  • 内容侵权
  • 侮辱谩骂
  • 样式问题
  • 其他
程序员都在用的中文IT技术交流社区

程序员都在用的中文IT技术交流社区

专业的中文 IT 技术社区,与千万技术人共成长

专业的中文 IT 技术社区,与千万技术人共成长

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

客服 返回顶部