btrbk: add mainProgram (#356350)
[NixPkgs.git] / pkgs / development / python-modules / finetuning-scheduler / default.nix
blob808c8e1ab6962e03dce6dc8cbcc2e0c94ed833ae
2   stdenv,
3   lib,
4   buildPythonPackage,
5   fetchFromGitHub,
6   setuptools,
7   pythonOlder,
8   pytestCheckHook,
9   torch,
10   pytorch-lightning,
13 buildPythonPackage rec {
14   pname = "finetuning-scheduler";
15   version = "2.4.0";
16   pyproject = true;
18   disabled = pythonOlder "3.9";
20   src = fetchFromGitHub {
21     owner = "speediedan";
22     repo = "finetuning-scheduler";
23     rev = "refs/tags/v${version}";
24     hash = "sha256-uSFGZseSJv519LpaddO6yP6AsIMZutEA0Y7Yr+mEWTQ=";
25   };
27   build-system = [ setuptools ];
29   dependencies = [
30     pytorch-lightning
31     torch
32   ];
34   # needed while lightning is installed as package `pytorch-lightning` rather than`lightning`:
35   env.PACKAGE_NAME = "pytorch";
37   nativeCheckInputs = [ pytestCheckHook ];
38   pytestFlagsArray = [ "tests" ];
39   disabledTests =
40     # torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
41     # LoweringException: ImportError: cannot import name 'triton_key' from 'triton.compiler.compiler'
42     lib.optionals (pythonOlder "3.12") [
43       "test_fts_dynamo_enforce_p0"
44       "test_fts_dynamo_resume"
45       "test_fts_dynamo_intrafit"
46     ]
47     ++ lib.optionals (stdenv.hostPlatform.isAarch64 && stdenv.hostPlatform.isLinux) [
48       # slightly exceeds numerical tolerance on aarch64-linux:
49       "test_fts_frozen_bn_track_running_stats"
50     ];
52   pythonImportsCheck = [ "finetuning_scheduler" ];
54   meta = {
55     description = "PyTorch Lightning extension for foundation model experimentation with flexible fine-tuning schedules";
56     homepage = "https://finetuning-scheduler.readthedocs.io";
57     changelog = "https://github.com/speediedan/finetuning-scheduler/blob/${src.rev}/CHANGELOG.md";
58     license = lib.licenses.asl20;
59     maintainers = with lib.maintainers; [ bcdarwin ];
60     # "No module named 'torch._C._distributed_c10d'; 'torch._C' is not a package" at import time:
61     broken = stdenv.hostPlatform.isDarwin;
62   };