Merge pull request #678 from kaizhang/master
[MACS.git] / test / test_HMMR_HMM.py
blob34ce99a5887cdd6c65ddb0bc20f2c1015c552e67
1 import unittest
2 import pytest
3 from MACS3.Signal.HMMR_HMM import *
4 import numpy as np
5 import numpy.testing as npt
7 # ------------------------------------
8 # Main function
9 # ------------------------------------
10 ''' This unittest is to check the outputs of the hmm_training() and hmm_predict() functions
11 '''
13 # @pytest.mark.skip(reason="need to refine later")
class Test_HMM_train(unittest.TestCase):
    """Check hmm_training() and hmm_predict() against stored reference values.

    Both test methods are currently skipped (see the skip reasons on each
    method): the reference numbers may not be reproduced with different
    sklearn + hmmlearn versions.
    """

    def setUp(self):
        """Load the fixture files and define the reference model parameters."""
        # Function-scope import so this block does not depend on a
        # module-level ``import os``.
        import os

        # Resolve fixtures relative to this test module (which lives in the
        # ``test`` directory) instead of the current working directory, so
        # the tests can be launched from anywhere.
        fixture_dir = os.path.dirname(os.path.abspath(__file__))

        def fixture(name):
            """Return the absolute path of fixture file *name*."""
            return os.path.join(fixture_dir, name)

        # for training
        self.training_data = np.loadtxt(
            fixture("large_training_data.txt"),
            delimiter="\t",
            dtype="float",
            usecols=(2, 3, 4, 5),
        ).tolist()
        self.training_data_lengths = np.loadtxt(
            fixture("large_training_lengths.txt"), dtype="int"
        ).tolist()
        self.expected_converged = True
        # Sentinels for the assertNotEqual() sanity checks in test_training.
        self.not_expected_covars = None
        self.not_expected_means = None
        self.not_expected_transmat = None

        # Reference parameters: expected result of training, and also the
        # parameters of the hand-built model used in test_predict.
        self.startprob = [0.01807016, 0.90153727, 0.08039257]
        self.means = [[2.05560411e-01, 1.52959594e+00, 1.73568556e+00, 1.00019720e-04],
                      [1.84467806e-01, 1.46784946e+00, 1.67895745e+00, 1.00016654e-04],
                      [2.06402305e+00, 8.60140461e+00, 7.22907032e+00, 1.00847661e-04]]
        self.covars = [[[ 1.19859257e-01,  5.33746506e-02,  3.99871507e-02,  1.49805047e-07],
                        [ 5.33746506e-02,  1.88774896e+00,  7.38204761e-01,  1.70902908e-07],
                        [ 3.99871507e-02,  7.38204761e-01,  2.34175176e+00,  1.75654357e-07],
                        [ 1.49805047e-07,  1.70902908e-07,  1.75654357e-07,  1.45312288e-07]],
                       [[ 1.06135330e-01,  4.16846792e-02,  3.24447289e-02,  1.30393434e-07],
                        [ 4.16846792e-02,  1.75537103e+00,  6.70848135e-01,  1.49425940e-07],
                        [ 3.24447289e-02,  6.70848135e-01,  2.22285392e+00,  1.52914017e-07],
                        [ 1.30393434e-07,  1.49425940e-07,  1.52914017e-07,  1.27205162e-07]],
                       [[ 5.94746590e+00,  5.24388615e+00, -5.33166471e-01, -1.47228883e-06],
                        [ 5.24388615e+00,  2.63945986e+01,  3.54212739e+00, -6.03892201e-06],
                        [-5.33166471e-01,  3.54212739e+00,  1.50231166e+01,  1.43141422e-05],
                        [-1.47228883e-06, -6.03892201e-06,  1.43141422e-05,  1.04240673e-07]]]
        self.transmat = [[1.91958645e-03, 9.68166646e-01, 2.99137676e-02],
                         [8.52453717e-01, 1.46924953e-01, 6.21329356e-04],
                         [2.15432113e-02, 6.80080650e-05, 9.78388781e-01]]
        self.n_features = 4

        # for prediction
        self.prediction_data = np.loadtxt(
            fixture("small_prediction_data.txt"),
            delimiter="\t",
            dtype="float",
            usecols=(2, 3, 4, 5),
        ).tolist()
        self.prediction_data_lengths = np.loadtxt(
            fixture("small_prediction_lengths.txt"), dtype="int"
        ).tolist()
        self.predictions = np.loadtxt(
            fixture("small_prediction_results.txt"), delimiter="\t", dtype="float"
        ).tolist()

    @pytest.mark.skip(reason="it may fail with different sklearn+hmmlearn")
    def test_training(self):
        """Train a 3-state, full-covariance HMM and compare it to the reference."""
        # test hmm_training:
        model = hmm_training(
            training_data=self.training_data,
            training_data_lengths=self.training_data_lengths,
            n_states=3,
            random_seed=12345,
            covar='full',
        )
        # Print the fitted parameters so they are available in the test
        # output when one of the comparisons below fails.
        print(model.startprob_)
        print(model.means_)
        print(model.covars_)
        print(model.transmat_)
        print(model.n_features)

        self.assertEqual(model.monitor_.converged, self.expected_converged)
        # Sanity checks against the None sentinels defined in setUp().
        self.assertNotEqual(model.covars_.tolist(), self.not_expected_covars)
        self.assertNotEqual(model.means_.tolist(), self.not_expected_means)
        self.assertNotEqual(model.transmat_.tolist(), self.not_expected_transmat)

        npt.assert_allclose(model.startprob_.tolist(), self.startprob)
        npt.assert_allclose(model.means_, self.means)
        npt.assert_allclose(model.covars_, self.covars)
        npt.assert_allclose(model.transmat_, self.transmat)
        # n_features is an integer count: compare exactly rather than with a
        # floating-point tolerance.
        self.assertEqual(model.n_features, self.n_features)

    @pytest.mark.skip(reason="it may fail with different sklearn+hmmlearn")
    def test_predict(self):
        """Run hmm_predict() on a model built from the reference parameters."""
        # test hmm_predict
        hmm_model = GaussianHMM(n_components=3, covariance_type='full')
        hmm_model.startprob_ = np.array(self.startprob)
        hmm_model.transmat_ = np.array(self.transmat)
        hmm_model.means_ = np.array(self.means)
        hmm_model.covars_ = np.array(self.covars)
        hmm_model.covariance_type = 'full'
        hmm_model.n_features = self.n_features
        predictions = hmm_predict(self.prediction_data, self.prediction_data_lengths, hmm_model)

        ## This is to write the prediction results into a file for 'correct' answer
        #with open("test/small_prediction_results.txt","w") as f:
        #    for x,y,z in predictions:
        #        f.write( str(x)+"\t"+str(y)+"\t"+str(z)+"\n")

        npt.assert_allclose(predictions, self.predictions)