Myblog

the amazing thing of think

python Jax

jax

jax 是Google 开源的非官方机器学习框架,个人感觉 清爽高效。而且相对于tensorflow,pytorch 围绕张量展开计算。还是喜欢基于矩阵进行的计算。另外一点就是Jax 是jax.numpy 和 jax.scipy 是用TPU加速的。对神经网络芯片软加速有一定的借鉴意义。
内有两大神器

  • autograd /pytorch 倒是也有
  • xla /是一种针对特定领域的线性代数编译器
    另一点就是
  • Jax 重写了random算法。这个和其他的都不太一样。

Jax的生态。从Awesome Jax 翻译过来包括以下这些。但我都没用过
Flax - 一个灵活的库,拥有所有JAX NN库中最大的用户群。
Haiku–专注于简单,由DeepMind的Sonnet作者创建。
Objax - 具有类似于PyTorch的面向对象设计。
Elegy–实现了Keras API的一些改进。
RLax - 实现强化学习代理的库。
Trax–一个 “考虑使用电池 “的深度学习库,专注于为常见工作负载提供解决方案。应该是偏节能。
Jraph - 一个轻量级的图神经网络库。
NumPyro - 基于Pyro库的概率编程。
Chex - 编写和测试可靠的JAX代码的实用工具。
Optax - 一个梯度处理和优化库。
JAX, M.D. - 加速的微分分子动力学。
Coax - 将RL论文转化为代码,简单的方法。
SymJAX - 符号CPU/GPU/TPU编程。
mcx - 为执行推理表达和编译概率程序。
Parallax - JAX的不可变torch 模块原型/应该算 能TPU加速的torch。
FedJAX - JAX中的联合学习,建立在Optax和Haiku之上。
jax-unirep - 为蛋白质机器学习应用实现UniRep模型的库。
jax-flows - 在JAX中规范化流。
sklearn-jax-kernels - 基于Jax 的scikit-learn kernel.
jax-cosmo - 一个可区分的宇宙学库。
efax - JAX中的指数族。
mpi4jax - 将MPI操作与CPU和GPU上的Jax代码相结合。