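# NixOS VM tests for local-ai. Each attribute below is an independent test
# derivation; assuming these are exposed as the package's passthru.tests, a
# single test can be built with e.g.:
#
#   nix-build -A local-ai.tests.health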
{
  self,
  lib,
  testers,
  fetchzip,
  fetchurl,
  writers,
  symlinkJoin,
  jq,
  prom2json,
}:
let
  # NixOS configuration shared by every VM test below: it imports the service
  # module and runs the package under test with debug logging.
  common-config =
    { config, ... }:
    {
      imports = [ ./module.nix ];
      services.local-ai = {
        enable = true;
        package = self;
        threads = config.virtualisation.cores;
        logLevel = "debug";
      };
    };
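
  # genModels comes from the package's own lib (self.lib); as used below, it
  # turns an attrset of model configurations into a derivation that can be
  # passed to services.local-ai.models (presumably one config file per model).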
  inherit (self.lib) genModels;
in
{
  version = testers.testVersion {
    package = self;
    version = "v" + self.version;
    command = "local-ai --help";
  };
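
  # Smoke test: wait for the HTTP API to come up, query /readyz and collect
  # the Prometheus metrics.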
  health = testers.runNixOSTest {
    name = self.name + "-health";
    nodes.machine = {
      imports = [ common-config ];
      virtualisation.memorySize = 2048;
    };
    testScript =
      let
        port = "8080";
      in
      ''
        machine.wait_for_open_port(${port})
        machine.succeed("curl -f http://localhost:${port}/readyz")

        machine.succeed("${prom2json}/bin/prom2json http://localhost:${port}/metrics > metrics.json")
        machine.copy_from_vm("metrics.json")
      '';
  };
}
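# The remaining tests are gated on the package's feature flags: they are only
# included when the relevant GPU backends (cuBLAS/CLBlast) are disabled,
# presumably because the test VMs provide no GPU.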
// lib.optionalAttrs (!self.features.with_cublas) {
  # https://localai.io/features/embeddings/#bert-embeddings
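  # Serves an all-MiniLM-L6-v2 model via the bert-embeddings backend and checks
  # that the /embeddings endpoint returns a response attributed to it.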
  bert =
    let
      model = "embedding";
      model-configs.${model} = {
        # Note: q4_0 and q4_1 models cannot be loaded
        parameters.model = fetchurl {
          url = "https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-f16.bin";
          hash = "sha256-nBlbJFOk/vYKT2vjqIo5IRNmIU32SYpP5IhcniIxT1A=";
        };
        backend = "bert-embeddings";
        embeddings = true;
      };

      models = genModels model-configs;

      requests.request = {
        inherit model;
        input = "Your text string goes here";
      };
    in
    testers.runNixOSTest {
      name = self.name + "-bert";
      nodes.machine = {
        imports = [ common-config ];
        virtualisation.cores = 2;
        virtualisation.memorySize = 4096;
        services.local-ai.models = models;
      };
      passthru = {
        inherit models requests;
      };
      testScript =
        let
          port = "8080";
        in
        ''
          machine.wait_for_open_port(${port})
          machine.succeed("curl -f http://localhost:${port}/readyz")
          machine.succeed("curl -f http://localhost:${port}/v1/models --output models.json")
          machine.succeed("${jq}/bin/jq --exit-status 'debug | .data[].id == \"${model}\"' models.json")

          machine.succeed("curl -f http://localhost:${port}/embeddings --json @${writers.writeJSON "request.json" requests.request} --output embeddings.json")
          machine.copy_from_vm("embeddings.json")
          machine.succeed("${jq}/bin/jq --exit-status 'debug | .model == \"${model}\"' embeddings.json")

          machine.succeed("${prom2json}/bin/prom2json http://localhost:${port}/metrics > metrics.json")
          machine.copy_from_vm("metrics.json")
        '';
    };
}
// lib.optionalAttrs (!self.features.with_cublas && !self.features.with_clblas) {
  # https://localai.io/docs/getting-started/manual/
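  # Serves a Llama 3.1 8B Instruct GGUF model via the llama-cpp backend and
  # exercises the chat, edit and text completion endpoints.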
  llama =
    let
      model = "gpt-3.5-turbo";

      # https://localai.io/advanced/#full-config-model-file-reference
      model-configs.${model} = rec {
        context_size = 16 * 1024; # a context_size of 128k is possible, but needs 16GB RAM
        backend = "llama-cpp";
        parameters = {
          # https://ai.meta.com/blog/meta-llama-3-1/
          model = fetchurl {
            url = "https://huggingface.co/lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf";
            hash = "sha256-8r4+GiOcEsnz8BqWKxH7KAf4Ay/bY7ClUC6kLd71XkQ=";
          };
          # defaults from:
          # https://deepinfra.com/meta-llama/Meta-Llama-3.1-8B-Instruct
          temperature = 0.7;
          top_p = 0.9;
          top_k = 0;
          # The following parameter leads to outputs like: !!!!!!!!!!!!!!!!!!!
          #repeat_penalty = 1;
          presence_penalty = 0;
          frequency_penalty = 0;
          max_tokens = 100;
        };
        stopwords = [ "<|eot_id|>" ];
        template = {
          # Templates implement the following specification:
          # https://github.com/meta-llama/llama3/tree/main?tab=readme-ov-file#instruction-tuned-models
          # ... and are inspired by:
          # https://github.com/mudler/LocalAI/blob/master/embedded/models/llama3-instruct.yaml
          #
          # The rules for template evaluation are defined here:
          # https://pkg.go.dev/text/template
          chat_message = ''
            <|start_header_id|>{{.RoleName}}<|end_header_id|>

            {{.Content}}${builtins.head stopwords}'';

          chat = "{{.Input}}<|start_header_id|>assistant<|end_header_id|>";

          completion = "{{.Input}}";
        };
      };

      models = genModels model-configs;

      requests = {
        # https://localai.io/features/text-generation/#chat-completions
        chat-completions = {
          inherit model;
          messages = [
            {
              role = "user";
              content = "1 + 2 = ?";
            }
          ];
        };
        # https://localai.io/features/text-generation/#edit-completions
        edit-completions = {
          inherit model;
          instruction = "rephrase";
          input = "Black cat jumped out of the window";
          max_tokens = 50;
        };
        # https://localai.io/features/text-generation/#completions
        completions = {
          inherit model;
          prompt = "A long time ago in a galaxy far, far away";
        };
      };
    in
    testers.runNixOSTest {
      name = self.name + "-llama";
      nodes.machine = {
        imports = [ common-config ];
        virtualisation.cores = 4;
        virtualisation.memorySize = 8192;
        services.local-ai.models = models;
        # TODO: Add a test case for parallel requests
        services.local-ai.parallelRequests = 2;
      };
      passthru = {
        inherit models requests;
      };
      testScript =
        let
          port = "8080";
        in
        ''
          machine.wait_for_open_port(${port})
          machine.succeed("curl -f http://localhost:${port}/readyz")
          machine.succeed("curl -f http://localhost:${port}/v1/models --output models.json")
          machine.succeed("${jq}/bin/jq --exit-status 'debug | .data[].id == \"${model}\"' models.json")

          machine.succeed("curl -f http://localhost:${port}/v1/chat/completions --json @${writers.writeJSON "request-chat-completions.json" requests.chat-completions} --output chat-completions.json")
          machine.copy_from_vm("chat-completions.json")
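          # The chat request asks "1 + 2 = ?"; the last whitespace-separated token
          # of the reply is expected to parse as the number 3.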
          machine.succeed("${jq}/bin/jq --exit-status 'debug | .object == \"chat.completion\"' chat-completions.json")
          machine.succeed("${jq}/bin/jq --exit-status 'debug | .choices | first.message.content | split(\" \") | last | tonumber == 3' chat-completions.json")

          machine.succeed("curl -f http://localhost:${port}/v1/edits --json @${writers.writeJSON "request-edit-completions.json" requests.edit-completions} --output edit-completions.json")
          machine.copy_from_vm("edit-completions.json")
          machine.succeed("${jq}/bin/jq --exit-status 'debug | .object == \"edit\"' edit-completions.json")
          machine.succeed("${jq}/bin/jq --exit-status '.usage.completion_tokens | debug == ${toString requests.edit-completions.max_tokens}' edit-completions.json")

          machine.succeed("curl -f http://localhost:${port}/v1/completions --json @${writers.writeJSON "request-completions.json" requests.completions} --output completions.json")
          machine.copy_from_vm("completions.json")
          machine.succeed("${jq}/bin/jq --exit-status 'debug | .object == \"text_completion\"' completions.json")
          machine.succeed("${jq}/bin/jq --exit-status '.usage.completion_tokens | debug == ${
            toString model-configs.${model}.parameters.max_tokens
          }' completions.json")

          machine.succeed("${prom2json}/bin/prom2json http://localhost:${port}/metrics > metrics.json")
          machine.copy_from_vm("metrics.json")
        '';
    };
}
//
  lib.optionalAttrs
    (self.features.with_tts && !self.features.with_cublas && !self.features.with_clblas)
    {
      # https://localai.io/features/text-to-audio/#piper
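      # Round trip: synthesize speech with the piper backend, transcribe it back
      # with the whisper backend and compare the transcription to the input text.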
      tts =
        let
          model-stt = "whisper-en";
          model-configs.${model-stt} = {
            backend = "whisper";
            parameters.model = fetchurl {
              url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin";
              hash = "sha256-x3xXZvHO8JtrfUfyG1Rsvd1BV4hrO11tT3CekeZsfCs=";
            };
          };

          model-tts = "piper-en";
          model-configs.${model-tts} = {
            backend = "piper";
            parameters.model = "en-us-danny-low.onnx";
          };
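
          # The generated model configs are merged with the downloaded piper voice,
          # presumably because the piper backend resolves "en-us-danny-low.onnx" by
          # name from within the models directory.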
          models =
            let
              models = genModels model-configs;
            in
            symlinkJoin {
              inherit (models) name;
              paths = [
                models
                (fetchzip {
                  url = "https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-danny-low.tar.gz";
                  hash = "sha256-5wf+6H5HeQY0qgdqnAG1vSqtjIFM9lXH53OgouuPm0M=";
                  stripRoot = false;
                })
              ];
            };

          requests.request = {
            model = model-tts;
            input = "Hello, how are you?";
          };
        in
        testers.runNixOSTest {
          name = self.name + "-tts";
          nodes.machine = {
            imports = [ common-config ];
            virtualisation.cores = 2;
            services.local-ai.models = models;
          };
          passthru = {
            inherit models requests;
          };
          testScript =
            let
              port = "8080";
            in
            ''
              machine.wait_for_open_port(${port})
              machine.succeed("curl -f http://localhost:${port}/readyz")
              machine.succeed("curl -f http://localhost:${port}/v1/models --output models.json")
              machine.succeed("${jq}/bin/jq --exit-status 'debug' models.json")

              machine.succeed("curl -f http://localhost:${port}/tts --json @${writers.writeJSON "request.json" requests.request} --output out.wav")
              machine.copy_from_vm("out.wav")

              machine.succeed("curl -f http://localhost:${port}/v1/audio/transcriptions --header 'Content-Type: multipart/form-data' --form file=@out.wav --form model=${model-stt} --output transcription.json")
              machine.copy_from_vm("transcription.json")
              machine.succeed("${jq}/bin/jq --exit-status 'debug | .segments | first.text == \"${requests.request.input}\"' transcription.json")

              machine.succeed("${prom2json}/bin/prom2json http://localhost:${port}/metrics > metrics.json")
              machine.copy_from_vm("metrics.json")
            '';
        };
    }