您现在的位置是:首页 > 文章详情

谷歌开源 TensorFlow 的简化库 JAX

日期:2018-12-15点击:487

谷歌开源了一个 TensorFlow 的简化库 JAX。


JAX 结合了 Autograd 和 XLA,专门用于高性能机器学习研究。

凭借 Autograd,JAX 可以求导循环、分支、递归和闭包函数,并且它可以进行三阶求导。通过 grad,它支持自动模式反向求导(反向传播)和正向求导,且二者可以任何顺序任意组合。

得力于 XLA,可以在 GPU 和 TPU 上编译和运行 NumPy 程序。默认情况下,编译发生在底层,库调用实时编译和执行。但是 JAX 还允许使用单一函数 API jit 将 Python 函数及时编译为 XLA 优化的内核。编译和自动求导可以任意组合,因此可以在 Python 环境下实现复杂的算法并获得最大的性能。

demo:

import jax.numpy as np from jax import grad, jit, vmap from functools import partial def predict(params, inputs):   for W, b in params:     outputs = np.dot(inputs, W) + b     inputs = np.tanh(outputs)   return outputs def logprob_fun(params, inputs, targets):   preds = predict(params, inputs)   return np.sum((preds - targets)**2) grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads

更深入地看,JAX 实际上是一个可扩展的可组合函数转换系统,grad 和 jit 都是这种转换的实例。

项目地址:https://github.com/google/JAX

原文链接:https://www.oschina.net/news/102714/google-opensource-jax
关注公众号

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。

持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。

转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。

文章评论

共有0条评论来说两句吧...

文章二维码

扫描即可查看该文章

点击排行

推荐阅读

最新文章