---@diagnostic disable: redefined-local
-- Let `require` resolve modules stored next to this script (presumably the
-- `test` helper required below) by prepending "<script dir>?.lua" to the
-- module search path.
do
    -- Everything up to and including the last '/' or '\' in the script path;
    -- an empty string when the script was started without a directory part.
    local script_dir = arg[0]:match("^(.-)[^/\\]*$")
    package.path = script_dir .. "?.lua;" .. package.path
end
local test = require("test")
local a = require("liba")
-- Hand the class table's metatable to the `test.dir` helper.
test.dir(getmetatable(a.regress_linear))
-- Sample data shared by the error and fitting calls: five inputs `x` and
-- the matching outputs `y` (each y[i] equals x[i] + 1).
local x, y = { 0, 1, 2, 3, 4 }, { 1, 2, 3, 4, 5 }
-- Model under test, built from a one-element coefficient table. The checks
-- below only pin the Lua type of each method's result, not its value.
local ctx = a.regress_linear.new({ 1 })
test.dir(getmetatable(ctx))

-- Assert that the first of the values in `...` has Lua type `expected`.
-- Callers always pass the call expression last, so every value it returns is
-- forwarded to `type` exactly as in an inline `assert(type(call) == expected)`.
local function expect_type(expected, ...)
    assert(type(...) == expected)
end

-- 1) Functions fetched from the class table, receiver passed explicitly.
expect_type("number", a.regress_linear.eval(ctx, { 1 }))
expect_type("table", a.regress_linear.err(ctx, x, y))
expect_type("userdata", a.regress_linear.gd(ctx, { 0 }, 1, 0.1))
expect_type("userdata", a.regress_linear.sgd(ctx, x, y, 0.1))
expect_type("userdata", a.regress_linear.bgd(ctx, x, y, 0.1))
expect_type("number", a.regress_linear.mgd(ctx, x, y, 1e-3, 1.0, 0.1))
expect_type("userdata", a.regress_linear.zero(ctx))

-- 2) The same functions fetched from the instance, receiver passed explicitly.
expect_type("number", ctx.eval(ctx, { 1 }))
expect_type("table", ctx.err(ctx, x, y))
expect_type("userdata", ctx.gd(ctx, { 0 }, 1, 0.1))
expect_type("userdata", ctx.sgd(ctx, x, y, 0.1))
expect_type("userdata", ctx.bgd(ctx, x, y, 0.1))
expect_type("number", ctx.mgd(ctx, x, y, 1e-3, 1.0, 0.1))
expect_type("userdata", ctx.zero(ctx))

-- 3) Method-call sugar: `ctx:f(...)` is shorthand for `ctx.f(ctx, ...)`.
expect_type("number", ctx:eval({ 1 }))
expect_type("table", ctx:err(x, y))
expect_type("userdata", ctx:gd({ 0 }, 1, 0.1))
expect_type("userdata", ctx:sgd(x, y, 0.1))
expect_type("userdata", ctx:bgd(x, y, 0.1))
expect_type("number", ctx:mgd(x, y, 1e-3, 1.0, 0.1))
expect_type("userdata", ctx:zero())
-- Instance fields and the Lua type each must have, checked in this order.
for _, spec in ipairs({ { "coef", "table" }, { "bias", "number" } }) do
    assert(type(ctx[spec[1]]) == spec[2])
end
-- Indexing the instance must also yield a truthy `__newindex`.
assert(ctx.__newindex)
---@class a.regress_linear
---@field __name string
---@field __index table
---@field __newindex table