Objax是一个为JAX提供面向对象层的机器学习框架。
objax的Python项目详细描述
奥贾克斯
Tutorials |Install |Documentation |Philosophy
这不是一个官方支持的Google产品。在
Objax是一个开源的机器学习框架,由于 极简的面向对象设计和可读的代码库。 {a5——它的名字来自于一个流行的名字^- 框架。 Objax是由研究人员为研究人员设计的,注重简单性和可理解性。 它的用户应该能够方便地阅读、理解、扩展和修改它以满足他们的需要。在
这是Objax的开发者资源库,只有很少的用户文档 在这里,要获得完整的文档,请转到objax.readthedocs.io。在
您可以在该项目的子目录中找到自述,例如:
用户安装指南
使用pip
安装Objax,如下所示:
pip install --upgrade objax
Objax支持gpu,但是假设您已经有了一些CUDA版本 安装。以下是额外步骤:
^{pr2}$有用的环境配置
以下是一些有用的选项:
# Prevent JAX from taking the whole GPU memory# (useful if you want to run several programs on a single GPU)exportXLA_PYTHON_CLIENT_PREALLOCATE=false
测试安装
您可以通过运行以下代码来测试安装:
importjaximportobjaxprint(f'Number of GPUs {jax.device_count()}')x=objax.random.normal(shape=(100,4))m=objax.nn.Linear(nin=4,nout=5)print('Matrix product shape',m(x).shape)# (100, 5)x=objax.random.normal(shape=(100,3,32,32))m=objax.nn.Conv2D(nin=3,nout=4,k=3)print('Conv2D return shape',m(x).shape)# (100, 4, 32, 32)
通常,如果使用CUDA运行此程序时出错,则可能意味着 安装CUDA或CuDNN时出现问题。在
运行代码示例
克隆代码存储库:
git clone https://github.com/google/objax.git
cd objax/examples
开发人员文档
以下是关于 development setup 还有一个guide on adding new code。在
- 项目
标签: