-
Notifications
You must be signed in to change notification settings - Fork 6
/
flake.nix
45 lines (40 loc) · 1.1 KB
/
flake.nix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
{
  # Flake providing the torch2jax Python package and a development shell,
  # for every system supported by flake-utils' eachDefaultSystem.
  inputs = {
    nixpkgs.url = "github:nixos/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };

  outputs = { self, nixpkgs, flake-utils }:
    flake-utils.lib.eachDefaultSystem (system:
      let
        pkgs = nixpkgs.legacyPackages.${system};
        # Python package set used for both the package and the dev shell.
        py = pkgs.python3.pkgs;

        # Use the prebuilt jaxlib-bin whenever the source-built jaxlib is
        # marked broken on this system.
        jaxlib' = if py.jaxlib.meta.broken then py.jaxlib-bin else py.jaxlib;

        torch2jax = py.buildPythonPackage {
          pname = "torch2jax";
          version = "0.0.1";
          # NOTE(review): with `pyproject = true`, recent nixpkgs expects the
          # PEP 517 build backend to be listed in `build-system`; confirm the
          # backend named in pyproject.toml is available to this build.
          pyproject = true;
          src = ./.;
          propagatedBuildInputs = [ py.jax py.torch ];
          # jaxlib is only supplied for the check phase here; it is not
          # propagated to consumers of the package.
          nativeCheckInputs = [ jaxlib' py.pytestCheckHook py.torchvision ];
          # torchvision downloads models into HOME.
          preCheck = ''
            export HOME=$(mktemp -d)
          '';
          pythonImportsCheck = [ "torch2jax" ];
        };

        shell = pkgs.mkShell {
          buildInputs = [
            pkgs.act
            pkgs.ruff
          ] ++ (with py; [
            build
            ipython
            jax
            jaxlib'
            pytest
            torch
            torchvision
            twine
          ]);
        };
      in {
        # Current flake output schema.
        packages.default = torch2jax;
        devShells.default = shell;

        # Deprecated output names, kept so existing invocations such as
        # `nix build .#defaultPackage.<system>` keep working.
        defaultPackage = torch2jax;
        devShell = shell;
      });
}