11 cudaSupport ? config.cudaSupport,
15 buildPythonPackage rec {
16 pname = "causal-conv1d";
20 src = fetchFromGitHub {
22 repo = "causal-conv1d";
23 rev = "refs/tags/v${version}";
24 hash = "sha256-p5x5u3zEmEMN3mWd88o3jmcpKUnovTvn7I9jIOj/ie0=";
33 nativeBuildInputs = [ which ];
36 lib.optionals cudaSupport (
39 cuda_cudart # cuda_runtime.h, -lcudart
41 libcusparse # cusparse.h
42 libcusolver # cusolverDn.h
53 # pytest tests are not enabled because they require an NVIDIA GPU
54 pythonImportsCheck = [ "causal_conv1d" ];
57 CAUSAL_CONV1D_FORCE_BUILD = "TRUE";
58 } // lib.optionalAttrs cudaSupport { CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}"; };
61 description = "Causal depthwise conv1d in CUDA with a PyTorch interface";
62 homepage = "https://github.com/Dao-AILab/causal-conv1d";
63 license = licenses.bsd3;
64 maintainers = with maintainers; [ cfhammill ];
65 # The package requires CUDA or ROCm. The ROCm build hasn't been completed
66 # or tested, so the package is marked broken unless CUDA support is enabled.
67 broken = !cudaSupport;