1 # Copyright 2014 Google Inc. All Rights Reserved.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
15 """Test routines to generate dummy certificates."""
28 class Server(BaseHTTPServer
.HTTPServer
):
def __init__(self, https_root_ca_cert_path):
    """Create an HTTP server on localhost and wrap its socket with SSL.

    Args:
      https_root_ca_cert_path: path handed to ssl.wrap_socket as `certfile`.
    """
    # Port 0 asks the OS for any free port.
    bind_address = ('localhost', 0)
    handler_class = BaseHTTPServer.BaseHTTPRequestHandler
    BaseHTTPServer.HTTPServer.__init__(self, bind_address, handler_class)

    # Replace the plain listening socket with a server-side SSL wrapper;
    # the handshake is not performed automatically on connect.
    plain_socket = self.socket
    self.socket = ssl.wrap_socket(
        plain_socket,
        certfile=https_root_ca_cert_path,
        server_side=True,
        do_handshake_on_connect=False)
38 thread
= threading
.Thread(target
=self
.serve_forever
)
46 except KeyboardInterrupt:
49 def __exit__(self
, type_
, value_
, traceback_
):
53 class CertutilsTest(unittest
.TestCase
):
55 def _check_cert_file(self
, cert_file_path
, cert_str
, key_str
=None):
56 cert_load
= open(cert_file_path
, 'r').read()
58 expected_cert
= key_str
+ cert_str
60 expected_cert
= cert_str
61 self
.assertEqual(expected_cert
, cert_load
)
64 self
._temp
_dir
= tempfile
.mkdtemp(prefix
='certutils_', dir='/tmp')
68 shutil
.rmtree(self
._temp
_dir
)
def test_generate_dummy_ca_cert(self):
    """The subject given to generate_dummy_ca_cert becomes the cert's CN."""
    expected_common_name = 'testSubject'
    # generate_dummy_ca_cert returns a pair; only the first item is checked.
    ca_cert, _ = certutils.generate_dummy_ca_cert(expected_common_name)
    loaded_cert = certutils.load_cert(ca_cert)
    actual_common_name = loaded_cert.get_subject().commonName
    self.assertEqual(actual_common_name, expected_common_name)
76 def test_get_host_cert(self
):
77 ca_cert_path
= os
.path
.join(self
._temp
_dir
, 'rootCA.pem')
79 certutils
.write_dummy_ca_cert(*certutils
.generate_dummy_ca_cert(issuer
),
80 cert_path
=ca_cert_path
)
82 with
Server(ca_cert_path
) as server
:
83 cert_str
= certutils
.get_host_cert('localhost', server
.server_port
)
84 cert
= certutils
.load_cert(cert_str
)
85 self
.assertEqual(issuer
, cert
.get_subject().commonName
)
def test_get_host_cert_gives_empty_for_bad_host(self):
    """get_host_cert yields '' for a deliberately invalid host name."""
    bogus_host = 'not_a_valid_host_name_2472341234234234'
    self.assertEqual('', certutils.get_host_cert(bogus_host))
def test_write_dummy_ca_cert(self):
    """write_dummy_ca_cert creates the .pem file plus three sibling files.

    Verifies that none of the four files exist beforehand, then that the
    main .pem holds key + cert, the '-cert.pem' and '-cert.cer' files hold
    the cert alone, and the '-cert.p12' file exists (contents unchecked).
    """
    stem = os.path.join(self._temp_dir, 'testCA')
    pem_path, cert_pem_path, android_cer_path, windows_p12_path = [
        stem + suffix
        for suffix in ('.pem', '-cert.pem', '-cert.cer', '-cert.p12')]

    # Precondition: nothing has been written yet.
    for path in (pem_path, cert_pem_path, android_cer_path,
                 windows_p12_path):
        self.assertFalse(os.path.exists(path))

    cert, key = certutils.generate_dummy_ca_cert()
    certutils.write_dummy_ca_cert(cert, key, pem_path)

    # Only the main .pem file is expected to include the key.
    self._check_cert_file(pem_path, cert, key)
    for path in (cert_pem_path, android_cer_path):
        self._check_cert_file(path, cert)
    # The .p12 file is only checked for existence.
    self.assertTrue(os.path.exists(windows_p12_path))
110 def test_generate_cert(self
):
111 ca_cert_path
= os
.path
.join(self
._temp
_dir
, 'testCA.pem')
112 issuer
= 'testIssuer'
113 certutils
.write_dummy_ca_cert(
114 *certutils
.generate_dummy_ca_cert(issuer
), cert_path
=ca_cert_path
)
116 with
open(ca_cert_path
, 'r') as root_file
:
117 root_string
= root_file
.read()
118 subject
= 'testSubject'
119 cert_string
= certutils
.generate_cert(
120 root_string
, '', subject
)
121 cert
= certutils
.load_cert(cert_string
)
122 self
.assertEqual(issuer
, cert
.get_issuer().commonName
)
123 self
.assertEqual(subject
, cert
.get_subject().commonName
)
125 with
open(ca_cert_path
, 'r') as ca_cert_file
:
126 ca_cert_str
= ca_cert_file
.read()
127 cert_string
= certutils
.generate_cert(ca_cert_str
, cert_string
,
129 cert
= certutils
.load_cert(cert_string
)
130 self
.assertEqual(issuer
, cert
.get_issuer().commonName
)
131 self
.assertEqual(subject
, cert
.get_subject().commonName
)
134 if __name__
== '__main__':