change a_float to a_real
[liba.git] / lua / test / regress_linear.lua
blob 31225f945e3fd045e6a3cb94c27449918537258a
#!/usr/bin/env lua

---@diagnostic disable: redefined-local

-- Smoke test for the liba `regress_linear` binding. Every operation is called
-- three ways (through the module table, through the instance with an explicit
-- receiver, and with method-call syntax), and only the Lua type of each
-- result is asserted, not its value.

-- Let `require` find sibling modules (e.g. test.lua) next to this script by
-- prepending "<script dir>?.lua" to package.path. The capture is everything
-- in arg[0] before its last path component, i.e. the directory including its
-- trailing separator, or "" when the script is run from its own directory.
package.path = arg[0]:match("^(.-)[^/\\]*$") .. "?.lua;" .. package.path
local test = require("test")
local a = require("liba")
test.dir(getmetatable(a.regress_linear))

-- Sample data passed to err/sgd/bgd/mgd below.
local x = { 0, 1, 2, 3, 4 }
local y = { 1, 2, 3, 4, 5 }

local ctx = a.regress_linear.new({ 1 })
test.dir(getmetatable(ctx))
-- Module-table form: a.regress_linear.<op>(ctx, ...).
assert(type(a.regress_linear.eval(ctx, { 1 })) == "number")
assert(type(a.regress_linear.err(ctx, x, y)) == "table")
assert(type(a.regress_linear.gd(ctx, { 0 }, 1, 0.1)) == "userdata")
assert(type(a.regress_linear.sgd(ctx, x, y, 0.1)) == "userdata")
assert(type(a.regress_linear.bgd(ctx, x, y, 0.1)) == "userdata")
assert(type(a.regress_linear.mgd(ctx, x, y, 1e-3, 1.0, 0.1)) == "number")
assert(type(a.regress_linear.zero(ctx)) == "userdata")
-- Instance field lookup with an explicit receiver: ctx.<op>(ctx, ...).
assert(type(ctx.eval(ctx, { 1 })) == "number")
assert(type(ctx.err(ctx, x, y)) == "table")
assert(type(ctx.gd(ctx, { 0 }, 1, 0.1)) == "userdata")
assert(type(ctx.sgd(ctx, x, y, 0.1)) == "userdata")
assert(type(ctx.bgd(ctx, x, y, 0.1)) == "userdata")
assert(type(ctx.mgd(ctx, x, y, 1e-3, 1.0, 0.1)) == "number")
assert(type(ctx.zero(ctx)) == "userdata")
-- Method-call sugar: ctx:<op>(...).
assert(type(ctx:eval({ 1 })) == "number")
assert(type(ctx:err(x, y)) == "table")
assert(type(ctx:gd({ 0 }, 1, 0.1)) == "userdata")
assert(type(ctx:sgd(x, y, 0.1)) == "userdata")
assert(type(ctx:bgd(x, y, 0.1)) == "userdata")
assert(type(ctx:mgd(x, y, 1e-3, 1.0, 0.1)) == "number")
assert(type(ctx:zero()) == "userdata")
-- Fields: dump them, then assign new values and check the types read back.
test.dir(ctx.coef)
test.dir(ctx.bias)
ctx.coef = { 1 }
ctx.bias = 1
assert(type(ctx.coef) == "table")
assert(type(ctx.bias) == "number")
-- Assigning nil to these keys through the instance must not clear them:
-- each must still read back as a truthy value afterwards.
ctx.__name = nil
assert(ctx.__name)
ctx.__index = nil
assert(ctx.__index)
ctx.__newindex = nil
assert(ctx.__newindex)
---@class a.regress_linear
---@field __name string
---@field __index table
---@field __newindex table