# Copyright 2014 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sslproxy: SSL client/server connections using dummy certificates."""
class Client(object):
  """Minimal SSL client used to exercise the test Server.

  Presents the certificate at ca_cert_path, trusts the same file for peer
  verification, sends host_name as the SNI host name and a blank request,
  then tears the connection down.
  """

  def __init__(self, ca_cert_path, verify_cb, port, host_name='foo.com',
               host='localhost'):
    """Stores connection parameters; no network activity happens here.

    Args:
      ca_cert_path: PEM file used both as this client's certificate and as
        its verify locations.
      verify_cb: callback handed to the SSL context's set_verify().
      port: TCP port to connect to.
      host_name: SNI host name sent to the server.
      host: address to connect to.
    """
    self.ca_cert_path = ca_cert_path
    self.verify_cb = verify_cb
    self.port = port
    self.host_name = host_name
    self.host = host
    self.connection = None  # populated by run_request()

  def run_request(self):
    """Opens an SSL connection, sends a blank request and closes it.

    Raises:
      Whatever certutils / the SSL layer raise on handshake or verification
      failure; the tests expect certutils.Error in those cases.
    """
    context = certutils.get_ssl_context()
    context.set_verify(certutils.VERIFY_PEER, self.verify_cb)  # Demand a cert
    context.use_certificate_file(self.ca_cert_path)
    context.load_verify_locations(self.ca_cert_path)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.connection = certutils.get_ssl_connection(context, sock)
    self.connection.connect((self.host, self.port))
    self.connection.set_tlsext_host_name(self.host_name)

    try:
      self.connection.send('\r\n\r\n')
    finally:
      # Always release the socket, even if send() or shutdown() raises
      # (e.g. when the handshake is rejected).
      try:
        self.connection.shutdown()
      finally:
        self.connection.close()
class Handler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Request handler that consumes one request line and writes no reply."""

  # override BaseHTTPServer setting
  protocol_version = 'HTTP/1.1'

  def handle_one_request(self):
    """Reads a single request line (at most 65537 bytes) and keeps it."""
    line = self.rfile.readline(65537)
    self.raw_requestline = line
class WrappedErrorHandler(Handler):
  """Wraps handler to verify expected sslproxy errors are being raised.

  A certutils.Error raised by sslproxy._SetUpUsingDummyCert is recorded on
  self.server.error_function instead of propagating out of setup().
  """

  def setup(self):
    Handler.setup(self)
    try:
      sslproxy._SetUpUsingDummyCert(self)
    except certutils.Error:
      # Record the failure for the test rather than aborting the request.
      self.server.error_function = certutils.Error

  def finish(self):
    Handler.finish(self)
    try:
      self.connection.shutdown()
    finally:
      # Close even if shutdown() raises, so the connection is never leaked.
      self.connection.close()
class DummyArchive(object):
  """Stateless placeholder stored as DummyFetch.http_archive."""
class DummyFetch(object):
  """Stand-in for the server's http_archive_fetch attribute."""

  def __init__(self):
    archive = DummyArchive()
    self.http_archive = archive
class Server(BaseHTTPServer.HTTPServer):
  """SSL test server, served from a background thread inside a with-block.

  Requests are handled by sslproxy.wrap_handler(Handler), or by
  WrappedErrorHandler when use_error_handler is set.
  """

  def __init__(self, ca_cert_path, use_error_handler=False, port=0,
               host='localhost'):
    """Loads the CA file and binds the listening socket.

    Args:
      ca_cert_path: CA PEM file; its text is kept in self.ca_cert_str for
        get_certificate().
      use_error_handler: serve with WrappedErrorHandler instead of the
        sslproxy-wrapped Handler.
      port: port to bind; 0 lets the OS choose (read back via server_port).
      host: address to bind.

    Raises:
      RuntimeError: if the underlying HTTPServer cannot be initialized.
    """
    self.ca_cert_path = ca_cert_path
    with open(ca_cert_path, 'r') as ca_file:
      self.ca_cert_str = ca_file.read()
    self.http_archive_fetch = DummyFetch()
    if use_error_handler:
      self.HANDLER = WrappedErrorHandler
    else:
      self.HANDLER = sslproxy.wrap_handler(Handler)
    try:
      BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER)
    except Exception as e:
      raise RuntimeError('Could not start HTTPSServer on port %d: %s'
                         % (port, e))

  def __enter__(self):
    """Starts serve_forever() on a background thread and returns self."""
    thread = threading.Thread(target=self.serve_forever)
    thread.daemon = True  # never keep the interpreter alive on exit
    thread.start()
    return self

  def __exit__(self, type_, value_, traceback_):
    """Stops the serve loop and releases the listening socket."""
    try:
      self.shutdown()
    except KeyboardInterrupt:
      # Still reach server_close() if interrupted while waiting for the
      # serve loop to stop.
      pass
    # shutdown() only stops serve_forever(); the socket stays bound until
    # server_close() is called.
    self.server_close()

  def get_certificate(self, host):
    """Generates a certificate for host from the CA loaded in __init__."""
    return certutils.generate_cert(self.ca_cert_str, '', host)
class TestClient(unittest.TestCase):
  """End-to-end tests of Client against an in-process Server."""

  _temp_dir = None

  def setUp(self):
    import os
    self._temp_dir = tempfile.mkdtemp(prefix='sslproxy_', dir='/tmp')
    # Join with a separator: plain concatenation would create the files next
    # to the temp dir (e.g. /tmp/sslproxy_XYZtestCA.pem), where tearDown's
    # rmtree never removes them.
    self.ca_cert_path = os.path.join(self._temp_dir, 'testCA.pem')
    self.cert_path = os.path.join(self._temp_dir, 'testCA-cert.cer')
    self.wrong_ca_cert_path = os.path.join(self._temp_dir, 'wrong.pem')
    self.wrong_cert_path = os.path.join(self._temp_dir, 'wrong-cert.cer')

    # Write both pem and cer files for certificates: the CA the server uses
    # and an unrelated "wrong" CA for test_wrong_cert.
    # NOTE(review): assumes write_dummy_ca_cert derives the *-cert.cer file
    # from cert_path, as cert_path / wrong_cert_path expect.
    certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                  cert_path=self.ca_cert_path)
    certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                  cert_path=self.wrong_ca_cert_path)

  def tearDown(self):
    if self._temp_dir:
      shutil.rmtree(self._temp_dir)

  def verify_cb(self, conn, cert, errnum, depth, ok):
    """A callback that verifies the certificate authentication worked.

    Args:
      conn: Connection object (unused here).
      cert: the certificate being checked; must provide has_expired() and
        get_notBefore().
      errnum: possible error number (unused here).
      depth: depth reported for cert (unused here).
      ok: 1 if the authentication worked, 0 if it didn't.

    Returns:
      1 or 0 depending on if the verification worked (i.e. ok).
    """
    self.assertFalse(cert.has_expired())
    # notBefore must lie in the past. Assumes get_notBefore() uses the same
    # YYYYMMDDhhmmssZ form, so string order matches time order.
    self.assertGreater(time.strftime('%Y%m%d%H%M%SZ', time.gmtime()),
                       cert.get_notBefore())
    return ok

  def test_no_host(self):
    with Server(self.ca_cert_path) as server:
      c = Client(self.cert_path, self.verify_cb, server.server_port, '')
      self.assertRaises(certutils.Error, c.run_request)

  def test_client_connection(self):
    with Server(self.ca_cert_path) as server:
      c = Client(self.cert_path, self.verify_cb, server.server_port, 'foo.com')
      c.run_request()

      # A second, different SNI host name on the same server.
      # NOTE(review): the original host literal is not visible; confirm.
      c = Client(self.cert_path, self.verify_cb, server.server_port,
                 'random.host')
      c.run_request()

  def test_wrong_cert(self):
    with Server(self.ca_cert_path, True) as server:
      c = Client(self.wrong_cert_path, self.verify_cb, server.server_port,
                 'foo.com')
      self.assertRaises(certutils.Error, c.run_request)
if __name__ == '__main__':
  signal.signal(signal.SIGINT, signal.SIG_DFL)  # Exit on Ctrl-C
  # Without this the script defines the tests but never runs them.
  unittest.main()