开发者 Eric Zhang 近期发布了 jax-js,这是一个为 Web 平台量身定制的纯 JavaScript 机器学习框架。它的核心愿景是将 Google DeepMind 的 JAX 框架能力带入浏览器,让前端环境也能拥有高性能的数值计算和自动微分能力。| blog | Github | #机器学习 #框架

长期以来,JavaScript 在重度数值计算领域一直处于劣势,原因在于其 JIT 引擎并非为紧密的数值循环而设计,甚至缺乏原生的快速整数类型。然而,WebAssembly 和 WebGPU 的成熟改变了游戏规则。jax-js 通过生成高效的 Wasm 和 WebGPU 内核,让程序能够以接近原生的速度在浏览器中运行,彻底绕过了 JavaScript 解释器的性能瓶颈。

在编程模型上,jax-js 高度还原了 JAX 的设计哲学。它支持程序追踪与 JIT 编译,可以将开发者编写的 JS 代码即时转化为 GPU 着色器指令。虽然由于 JavaScript 语言限制,它无法像 Python 那样支持运算符重载,必须使用类似 .mul() 的方法调用,但其 API 与 NumPy 和 JAX 几乎完全一致。为了解决 JS 缺乏引用计数析构函数的问题,它还借鉴了 Rust 的所有权语义,通过 .ref 系统精细管理内存。

功能方面,jax-js 完整保留了 JAX 的精髓,包括自动微分 grad、向量化变换 vmap 以及内核融合 jit。开发者展示了一个令人印象深刻的案例:在浏览器中从零开始训练 MNIST 神经网络,仅需数秒即可达到 99% 以上的准确率。更具实践意义的是,它能实时处理 18 万字的文学巨著,通过 CLIP 嵌入模型实现毫秒级的语义搜索。

性能表现上,jax-js 在 M4 Pro 芯片上的矩阵乘法算力超过了 3 TFLOPs。在特定基准测试中,其性能甚至优于 TensorFlow.js 和 ONNX 等成熟框架。这主要归功于其编译器架构,它能够根据输入形状自动优化并生成内核,而非仅仅依赖预构建的静态库。

从技术深度来看,jax-js 将框架分为负责自动微分和追踪的前端,以及负责执行内核的后端。其自动微分实现参考了 Tinygrad 的简洁设计,通过数学上的对偶变换,让开发者在实现一阶导数规则后,能够自然地获得任意高阶导数。这种架构不仅优雅,也为未来的内核融合与优化提供了极高的灵活性。

目前 jax-js 已在 GitHub 开源。尽管在卷积运算优化和 WebAssembly 多线程支持等方面仍有提升空间,但它已经证明了在浏览器中构建完整机器学习生态的可行性。对于希望在不依赖后端的情况下实现实时交互式 AI 应用的开发者来说,这无疑开启了一个新的可能。
 
 
Back to Top