Merge pull request #307098 from r-ryantm/auto-update/cilium-cli
[NixPkgs.git] / pkgs / development / python-modules / torchinfo / default.nix
blob59b7e99c9dfca93ffa987d253e73f7da260c9bff
1 { lib
2 , buildPythonPackage
3 , fetchFromGitHub
4 , fetchpatch
5 , pythonOlder
6 , torch
7 , torchvision
8 , pytestCheckHook
9 , transformers
12 buildPythonPackage rec {
13   pname = "torchinfo";
14   version = "1.8.0";
15   format = "setuptools";
17   disabled = pythonOlder "3.7";
19   src = fetchFromGitHub {
20     owner = "TylerYep";
21     repo = "torchinfo";
22     rev = "refs/tags/v${version}";
23     hash = "sha256-pPjg498aT8y4b4tqIzNxxKyobZX01u+66ScS/mee51Q=";
24   };
26   patches = [
27     (fetchpatch {  # Add support for Python 3.11 and pytorch 2.1
28       url = "https://github.com/TylerYep/torchinfo/commit/c74784c71c84e62bcf56664653b7f28d72a2ee0d.patch";
29       hash = "sha256-xSSqs0tuFpdMXUsoVv4sZLCeVnkK6pDDhX/Eobvn5mw=";
30       includes = [
31         "torchinfo/model_statistics.py"
32       ];
33     })
34   ];
36   propagatedBuildInputs = [
37     torch
38     torchvision
39   ];
41   nativeCheckInputs = [
42     pytestCheckHook
43     transformers
44   ];
46   preCheck = ''
47     export HOME=$(mktemp -d)
48   '';
50   disabledTests = [
51     # Skip as it downloads pretrained weights (require network access)
52     "test_eval_order_doesnt_matter"
53     "test_flan_t5_small"
54     # AssertionError in output
55     "test_google"
56     # "addmm_impl_cpu_" not implemented for 'Half'
57     "test_input_size_half_precision"
58   ];
60   disabledTestPaths = [
61     # Test requires network access
62     "tests/torchinfo_xl_test.py"
63   ];
65   pythonImportsCheck = [
66     "torchinfo"
67   ];
69   meta = with lib; {
70     description = "API to visualize pytorch models";
71     homepage = "https://github.com/TylerYep/torchinfo";
72     license = licenses.mit;
73     maintainers = with maintainers; [ petterstorvik ];
74   };