Merge branch 'maint-0.4.8'
[tor.git] / src / test / ntor_ref.py
blobe3307430e1ad7702e1cbb20187ec10fa6154875a
1 #!/usr/bin/python
2 # Copyright 2012-2019, The Tor Project, Inc
3 # See LICENSE for licensing information
5 """
6 ntor_ref.py
9 This module is a reference implementation for the "ntor" protocol
10 as proposed by Goldberg, Stebila, and Ustaoglu and as instantiated in
11 Tor Proposal 216.
13 It's meant to be used to validate Tor's ntor implementation. It
14 requires the curve25519 python module from the curve25519-donna
15 package.
17 *** DO NOT USE THIS IN PRODUCTION. ***
19 commands:
21 gen_kdf_vectors: Print out some test vectors for the RFC5869 KDF.
22 timing: Print a little timing information about this implementation's
23 handshake.
24 self-test: Try handshaking with ourself; make sure we can.
25 test-tor: Handshake with tor's ntor implementation via the program
26 src/test/test-ntor-cl; make sure we can.
28 """
30 # Future imports for Python 2.7, mandatory in 3.0
31 from __future__ import division
32 from __future__ import print_function
33 from __future__ import unicode_literals
35 import binascii
36 try:
37 import curve25519
38 curve25519mod = curve25519.keys
39 except ImportError:
40 curve25519 = None
41 import slownacl_curve25519
42 curve25519mod = slownacl_curve25519
44 import hashlib
45 import hmac
46 import subprocess
47 import sys
49 # **********************************************************************
50 # Helpers and constants
def HMAC(key, msg):
    """Compute HMAC-SHA256 over 'msg' keyed with 'key'; return the raw digest."""
    return hmac.new(key, msg, hashlib.sha256).digest()
def H(msg, tweak):
    """Tweaked hash H(msg, tweak).

    In this version of ntor the tweak is simply used as the HMAC-SHA256 key,
    so this is HMAC(key=tweak, msg=msg)."""
    digest = HMAC(msg=msg, key=tweak)
    return digest
def keyid(k):
    """Compute KEYID(k) for the public key 'k'.

    With curve25519 a public key serves as its own key ID, so this is just
    the key's serialized (32-byte) form.
    """
    serialized = k.serialize()
    return serialized
# Field sizes, in bytes, of the values exchanged in the handshake.
NODE_ID_LENGTH = 20   # ID: a digest of the server's identity key
KEYID_LENGTH = 32     # KEYID(B)
G_LENGTH = 32         # a serialized curve25519 public key
H_LENGTH = 32         # output of H(), an HMAC-SHA256 digest

PROTOID = b"ntor-curve25519-sha256-1"
# Per-purpose labels derived from PROTOID: tweaks for H() and the KDF's
# extract salt / expand info.
T_MAC, T_KEY, T_VERIFY, M_EXPAND = (
    PROTOID + suffix
    for suffix in (b":mac", b":key_extract", b":verify", b":key_expand"))
def H_mac(msg):
    """Tweaked hash H(msg, t_mac); used to compute and check AUTH."""
    return H(msg, T_MAC)


def H_verify(msg):
    """Tweaked hash H(msg, t_verify); used to derive 'verify'."""
    return H(msg, T_VERIFY)
class PrivateKey(curve25519mod.Private):
    """As curve25519mod.Private, but doesn't regenerate its public key
       every time you ask for it.

       Any constructor arguments are forwarded unchanged to
       curve25519mod.Private.__init__; with no arguments it behaves exactly
       like a default-constructed curve25519mod.Private.
    """
    def __init__(self, *args, **kwargs):
        curve25519mod.Private.__init__(self, *args, **kwargs)
        # Cache for get_public(); filled in on first use.
        self._memo_public = None

    def get_public(self):
        """Return this key's public key, computing it only on the first call."""
        if self._memo_public is None:
            self._memo_public = curve25519mod.Private.get_public(self)
        return self._memo_public
98 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Compare the numeric major version (as test_tor does) rather than the
# sys.version string, whose lexicographic ordering is fragile.
if sys.version_info[0] < 3:
    def int2byte(i):
        """Return the one-byte string whose value is 'i' (0 <= i < 256)."""
        return chr(i)
else:
    def int2byte(i):
        """Return the one-byte bytes object whose value is 'i' (0 <= i < 256)."""
        return bytes([i])
def kdf_rfc5869(key, salt, info, n):
    """RFC 5869 HKDF using HMAC-SHA256.

    Extracts a pseudorandom key from the input keying material 'key' using
    'salt', then expands it with 'info' into 'n' bytes of output.

    Raises ValueError if n exceeds 255 * H_LENGTH, the maximum output that
    RFC 5869 permits (the expand step's block counter is a single byte).
    """
    if n > 255 * H_LENGTH:
        raise ValueError("kdf_rfc5869 can produce at most %d bytes, not %d"
                         % (255 * H_LENGTH, n))

    # HKDF-Extract: PRK = HMAC-Hash(salt, IKM)
    prk = HMAC(key=salt, msg=key)

    # HKDF-Expand: T(i) = HMAC-Hash(PRK, T(i-1) | info | i) for i = 1, 2, ...
    # Collect blocks in a list and join once, instead of repeated bytes +=.
    blocks = []
    produced = 0
    last = b""
    i = 1
    while produced < n:
        last = HMAC(key=prk, msg=last + info + int2byte(i))
        blocks.append(last)
        produced += len(last)
        i += 1
    return b"".join(blocks)[:n]
def kdf_ntor(key, n):
    """Expand 'key' (ntor's secret_input) into 'n' bytes of key material,
    using the RFC 5869 KDF with salt t_key and info m_expand."""
    salt, info = T_KEY, M_EXPAND
    return kdf_rfc5869(key=key, salt=salt, info=info, n=n)
124 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def client_part1(node_id, pubkey_B):
    """Initial handshake, client side.

       From the specification:

         <<To send a create cell, the client generates a keypair x,X =
           KEYGEN(), and sends a CREATE cell with contents:

             NODEID:     ID             -- ID_LENGTH bytes
             KEYID:      KEYID(B)       -- H_LENGTH bytes
             CLIENT_PK:  X              -- G_LENGTH bytes
         >>

       Takes node_id -- a digest of the server's identity key,
             pubkey_B -- a public key for the server.
       Returns a tuple of (client secret key x, client->server message)"""

    assert len(node_id) == NODE_ID_LENGTH

    key_id = keyid(pubkey_B)
    seckey_x = PrivateKey()
    pubkey_X = seckey_x.get_public().serialize()

    message = node_id + key_id + pubkey_X

    # The message is NODEID | KEYID | CLIENT_PK; size-check each field
    # with its own length constant.
    assert len(message) == NODE_ID_LENGTH + KEYID_LENGTH + G_LENGTH
    return seckey_x, message
def hash_nil(x):
    """Pass-through "hash" handed to get_shared_key().

    Returning the input unchanged stops the curve25519 python lib from
    sha256-ing the shared secret on our behalf."""
    result = x
    return result
def bad_result(r):
    """Helper: given a result of multiplying a public key by a private key,
       return True iff one of the inputs was broken.

       A broken input shows up as an all-zero 32-byte result.  'r' is bytes,
       so the comparison must be against a bytes literal: with
       unicode_literals in effect an unprefixed literal is text, which never
       compares equal to bytes on Python 3 (the check would always pass).
    """
    assert len(r) == 32
    return r == b'\x00' * 32
def server(seckey_b, my_node_id, message, keyBytes=72):
    """Handshake step 2, server side.

       From the spec:

       <<
         The server generates a keypair of y,Y = KEYGEN(), and computes

           secret_input = EXP(X,y) | EXP(X,b) | ID | B | X | Y | PROTOID
           KEY_SEED = H(secret_input, t_key)
           verify = H(secret_input, t_verify)
           auth_input = verify | ID | B | Y | X | PROTOID | "Server"

         The server sends a CREATED cell containing:

           SERVER_PK:  Y                     -- G_LENGTH bytes
           AUTH:       H(auth_input, t_mac)  -- H_LENGTH bytes
       >>

       Takes seckey_b -- the server's secret key
             my_node_id -- the server's public key digest,
             message -- a message from a client
             keyBytes -- amount of key material to generate

       Returns a tuple of (key material, server->client reply), or None on
       error.
    """
    # The client's message is NODEID | KEYID(B) | X, as built by client_part1.
    keyid_start = NODE_ID_LENGTH
    x_start = keyid_start + KEYID_LENGTH
    assert len(message) == x_start + G_LENGTH

    if my_node_id != message[:keyid_start]:
        return None

    pubkey_B = seckey_b.get_public()
    # A KEYID mismatch is recorded rather than returned on immediately;
    # all the failure conditions are checked together after the reply is built.
    badness = keyid(pubkey_B) != message[keyid_start:x_start]

    pubkey_X = curve25519mod.Public(message[x_start:])
    seckey_y = PrivateKey()
    pubkey_Y = seckey_y.get_public()
    xy = seckey_y.get_shared_key(pubkey_X, hash_nil)
    xb = seckey_b.get_shared_key(pubkey_X, hash_nil)

    # secret_input = EXP(X,y) | EXP(X,b) | ID | B | X | Y | PROTOID
    secret_input = (xy + xb + my_node_id +
                    pubkey_B.serialize() +
                    pubkey_X.serialize() +
                    pubkey_Y.serialize() +
                    PROTOID)

    verify = H_verify(secret_input)

    # auth_input = verify | ID | B | Y | X | PROTOID | "Server"
    auth_input = (verify +
                  my_node_id +
                  pubkey_B.serialize() +
                  pubkey_Y.serialize() +
                  pubkey_X.serialize() +
                  PROTOID +
                  b"Server")

    msg = pubkey_Y.serialize() + H_mac(auth_input)

    # Per bad_result, an all-zero shared secret means an input was broken.
    badness |= bad_result(xb)
    badness |= bad_result(xy)

    if badness:
        return None

    # KEY_SEED is not computed separately: kdf_ntor's extract step is
    # HMAC(key=t_key, msg=secret_input), i.e. H(secret_input, t_key).
    keys = kdf_ntor(secret_input, keyBytes)

    return keys, msg
def client_part2(seckey_x, msg, node_id, pubkey_B, keyBytes=72):
    """Handshake step 3: client side again.

       From the spec:

       <<
         The client then checks Y is in G^* [see NOTE in the spec], and
         computes

           secret_input = EXP(Y,x) | EXP(B,x) | ID | B | X | Y | PROTOID
           KEY_SEED = H(secret_input, t_key)
           verify = H(secret_input, t_verify)
           auth_input = verify | ID | B | Y | X | PROTOID | "Server"

         The client verifies that AUTH == H(auth_input, t_mac).
       >>

       Takes seckey_x -- the secret key we generated in step 1.
             msg -- the message from the server.
             node_id -- the node_id we used in step 1.
             pubkey_B -- the same server public key we used in step 1.
             keyBytes -- the number of bytes we want to generate
       Returns key material, or None on error
    """
    # The server's reply is SERVER_PK (Y) | AUTH.
    assert len(msg) == G_LENGTH + H_LENGTH

    pubkey_Y = curve25519mod.Public(msg[:G_LENGTH])
    their_auth = msg[G_LENGTH:]

    pubkey_X = seckey_x.get_public()

    yx = seckey_x.get_shared_key(pubkey_Y, hash_nil)
    bx = seckey_x.get_shared_key(pubkey_B, hash_nil)

    # secret_input = EXP(Y,x) | EXP(B,x) | ID | B | X | Y | PROTOID
    secret_input = (yx + bx + node_id +
                    pubkey_B.serialize() +
                    pubkey_X.serialize() +
                    pubkey_Y.serialize() + PROTOID)

    verify = H_verify(secret_input)

    # auth_input = verify | ID | B | Y | X | PROTOID | "Server"
    auth_input = (verify + node_id +
                  pubkey_B.serialize() +
                  pubkey_Y.serialize() +
                  pubkey_X.serialize() + PROTOID +
                  b"Server")

    my_auth = H_mac(auth_input)

    # Compare the MACs with compare_digest rather than '!=', which may stop
    # at the first differing byte.
    badness = not hmac.compare_digest(my_auth, their_auth)
    # The "Y is in G^*" check: an all-zero EXP(Y,x) (or EXP(B,x)) is rejected.
    badness |= bad_result(yx) | bad_result(bx)

    if badness:
        return None

    return kdf_ntor(secret_input, keyBytes)
298 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def demo(node_id=b"iToldYouAboutStairs.", server_key=None):
    """
       Try to handshake with ourself.

       node_id -- the NODE_ID_LENGTH-byte server identity digest to use.
       server_key -- the server's PrivateKey.  When omitted, a fresh key is
                     generated for this call (a PrivateKey() default would be
                     created once at import time and shared by every call).
    """
    if server_key is None:
        server_key = PrivateKey()
    pubkey_B = server_key.get_public()

    x, create = client_part1(node_id, pubkey_B)
    reply = server(server_key, node_id, create)
    assert reply is not None, "server rejected the handshake"
    skeys, created = reply
    ckeys = client_part2(x, created, node_id, pubkey_B)
    assert ckeys is not None, "client rejected the server's reply"
    assert len(skeys) == 72
    assert len(ckeys) == 72
    assert skeys == ckeys
    print("OK")
312 # ======================================================================
def timing():
    """
       Use Python's timeit module to see how fast this nonsense is

       Times 1000 calls of demo() with a fixed 20-byte node id and server key.
    """
    import timeit
    # N must be bytes: demo concatenates it with the (bytes) key id.  Only
    # ntor_ref is imported, so this also works with the slownacl fallback.
    t = timeit.Timer(stmt="ntor_ref.demo(N,SK)",
                     setup="import ntor_ref;N=b'ABCD'*5;SK=ntor_ref.PrivateKey()")
    print(t.timeit(number=1000))
322 # ======================================================================
def kdf_vectors():
    """
       Generate some vectors to check our KDF.

       For each sample input, prints its repr and then 100 bytes of
       kdf_rfc5869(input, T_KEY, M_EXPAND) as a quoted hex string.
    """
    def kdf_vec(inp):
        # The KDF operates on bytes: encode the ASCII sample, and decode the
        # hex output back to text so the concatenation works on Python 2 and 3.
        k = kdf_rfc5869(inp.encode("ascii"), T_KEY, M_EXPAND, 100)
        print(repr(inp), "\n\"" + binascii.b2a_hex(k).decode("ascii") + "\"")
    kdf_vec("")
    kdf_vec("Tor")
    kdf_vec("AN ALARMING ITEM TO FIND ON YOUR CREDIT-RATING STATEMENT")
336 # ======================================================================
def test_tor():
    """
       Call the test-ntor-cl command-line program to make sure we can
       interoperate with Tor's ntor program
    """
    if sys.version_info[0] >= 3:
        enhex = lambda s: binascii.b2a_hex(s).decode("ascii")
    else:
        enhex = lambda s: binascii.b2a_hex(s)
    dehex = lambda s: binascii.a2b_hex(s.strip())

    PROG = "./src/test/test-ntor-cl"

    def run_prog(*args):
        """Run PROG with 'args'; return its stdout lines, hex-decoded.

        Waits for the program to exit (so no child or pipe is left behind)
        and raises CalledProcessError if it exits non-zero.
        """
        cmd = [PROG] + list(args)
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        out, _ = p.communicate()
        if p.returncode != 0:
            raise subprocess.CalledProcessError(p.returncode, cmd)
        return [dehex(line) for line in out.splitlines()]

    def tor_client1(node_id, pubkey_B):
        " returns (msg, state) "
        return run_prog("client1", enhex(node_id),
                        enhex(pubkey_B.serialize()))

    def tor_server1(seckey_b, node_id, msg, n):
        " returns (msg, keys) "
        return run_prog("server1", enhex(seckey_b.serialize()),
                        enhex(node_id), enhex(msg), str(n))

    def tor_client2(state, msg, n):
        " returns (keys,) "
        return run_prog("client2", enhex(state), enhex(msg), str(n))

    node_id = b"thisisatornodeid$#%^"
    seckey_b = PrivateKey()
    pubkey_B = seckey_b.get_public()

    # Do a pure-Tor handshake
    c2s_msg, c_state = tor_client1(node_id, pubkey_B)
    s2c_msg, s_keys = tor_server1(seckey_b, node_id, c2s_msg, 90)
    c_keys, = tor_client2(c_state, s2c_msg, 90)
    assert c_keys == s_keys
    assert len(c_keys) == 90

    # Try a mixed handshake with Tor as the client
    c2s_msg, c_state = tor_client1(node_id, pubkey_B)
    s_keys, s2c_msg = server(seckey_b, node_id, c2s_msg, 90)
    c_keys, = tor_client2(c_state, s2c_msg, 90)
    assert c_keys == s_keys
    assert len(c_keys) == 90

    # Now do a mixed handshake with Tor as the server
    c_x, c2s_msg = client_part1(node_id, pubkey_B)
    s2c_msg, s_keys = tor_server1(seckey_b, node_id, c2s_msg, 90)
    c_keys = client_part2(c_x, s2c_msg, node_id, pubkey_B, 90)
    assert c_keys == s_keys
    assert len(c_keys) == 90

    print("OK")
398 # ======================================================================
if __name__ == '__main__':
    # Map each command-line verb to the function that implements it; a
    # missing or unknown verb prints the module docstring as usage help.
    _COMMANDS = {
        'gen_kdf_vectors': kdf_vectors,
        'timing': timing,
        'self-test': demo,
        'test-tor': test_tor,
    }
    _command = _COMMANDS.get(sys.argv[1]) if len(sys.argv) >= 2 else None
    if _command is None:
        print(__doc__)
    else:
        _command()