Merge pull request #307098 from r-ryantm/auto-update/cilium-cli
[NixPkgs.git] / pkgs / development / python-modules / torchrl / default.nix
blob591e59302ea6ad022c5ed09991de55206c0eb926
1 { lib
2 , buildPythonPackage
3 , pythonOlder
4 , fetchFromGitHub
5 , ninja
6 , setuptools
7 , wheel
8 , which
9 , cloudpickle
10 , numpy
11 , torch
12 , ale-py
13 , gym
14 , pygame
15 , gymnasium
16 , mujoco
17 , moviepy
18 , git
19 , hydra-core
20 , tensorboard
21 , tqdm
22 , wandb
23 , packaging
24 , tensordict
25 , imageio
26 , pytest-rerunfailures
27 , pytestCheckHook
28 , pyyaml
29 , scipy
32 buildPythonPackage rec {
33   pname = "torchrl";
34   version = "0.3.1";
35   pyproject = true;
37   disabled = pythonOlder "3.8";
39   src = fetchFromGitHub {
40     owner = "pytorch";
41     repo = "rl";
42     rev = "refs/tags/v${version}";
43     hash = "sha256-lETW996IKPUGgZpe+cyzrXvVmDSwj5G4XFreFmGxReQ=";
44   };
46   nativeBuildInputs = [
47     ninja
48     setuptools
49     wheel
50     which
51   ];
53   propagatedBuildInputs = [
54     cloudpickle
55     numpy
56     packaging
57     tensordict
58     torch
59   ];
61   passthru.optional-dependencies = {
62     atari = [
63       ale-py
64       gym
65       pygame
66     ];
67     gym-continuous = [
68       gymnasium
69       mujoco
70     ];
71     rendering = [
72       moviepy
73     ];
74     utils = [
75       git
76       hydra-core
77       tensorboard
78       tqdm
79       wandb
80     ];
81   };
83   # torchrl needs to create a folder to store datasets
84   preBuild = ''
85     export D4RL_DATASET_DIR=$(mktemp -d)
86   '';
88   pythonImportsCheck = [
89     "torchrl"
90   ];
92   # We have to delete the source because otherwise it is used instead of the installed package.
93   preCheck = ''
94     rm -rf torchrl
96     export XDG_RUNTIME_DIR=$(mktemp -d)
97   '';
99   nativeCheckInputs = [
100     gymnasium
101     imageio
102     pytest-rerunfailures
103     pytestCheckHook
104     pyyaml
105     scipy
106   ]
107   ++ passthru.optional-dependencies.atari
108   ++ passthru.optional-dependencies.gym-continuous
109   ++ passthru.optional-dependencies.rendering;
111   disabledTests = [
112     # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
113     "test_vecenvs_env"
115     # ValueError: Can't write images with one color channel.
116     "test_log_video"
118     # Those tests require the ALE environments (provided by unpackaged shimmy)
119     "test_collector_env_reset"
120     "test_gym"
121     "test_gym_fake_td"
122     "test_recorder"
123     "test_recorder_load"
124     "test_rollout"
125     "test_parallel_trans_env_check"
126     "test_serial_trans_env_check"
127     "test_single_trans_env_check"
128     "test_td_creation_from_spec"
129     "test_trans_parallel_env_check"
130     "test_trans_serial_env_check"
131     "test_transform_env"
132   ];
134   meta = with lib; {
135     description = "A modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
136     homepage = "https://github.com/pytorch/rl";
137     changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
138     license = licenses.mit;
139     maintainers = with maintainers; [ GaetanLepage ];
140   };