reformatting to 80 columns
[secretsharing.git] / lib / secretsharing / shamir.rb
blobddebae1f5f09ce621021066a93004329507c99a5
1 require 'openssl'
2 require 'digest/sha1'
3 require 'base64'
5 module SecretSharing
# The SecretSharing::Shamir class can be used to share random
# secrets between n people, so that any k of them (2 <= k <= n) can
# recover the secret, but k-1 people learn nothing (in an
# information-theoretical sense, see the NOTE in create_shares) about the
# secret.
#
# For a theoretical background, see
# http://www.cs.tau.ac.il/~bchor/Shamir.html or
# http://en.wikipedia.org/wiki/Secret_sharing#Shamir.27s_scheme
#
# To share a secret, create a new SecretSharing::Shamir object and
# then call the create_random_secret() method. The secret is now in
# the secret attribute and the shares are an array in the shares attribute.
#
# To recover a secret, create a SecretSharing::Shamir object and
# add the necessary shares to it using the '<<' method. Once enough
# shares have been added, the secret can be recovered in the secret
# attribute.
class Shamir
  attr_reader :n, :k, :secret, :secret_bitlength, :shares

  DEFAULT_SECRET_BITLENGTH = 256

  # Largest secret bitlength whose shares can still be serialized.
  # Share#to_s stores the prime bitlength in nibbles as exactly two hex
  # digits (at most 255 nibbles), and create_shares uses a prime bitlength
  # of 4 * nibbles + 1 with nibbles = secret_bitlength / 4 + 1 (integer
  # division). 1019 -> 255 nibbles is the last value that fits; 1020 would
  # need 256 nibbles and produce share strings that cannot be parsed back.
  MAX_SECRET_BITLENGTH = 1019

  # To create a new SecretSharing::Shamir object, you can
  # pass either just n, or n and k.
  #
  # For example:
  #   s = SecretSharing::Shamir.new(5, 3)
  # to create an object for 3 out of 5 secret sharing.
  #
  # or
  #   s = SecretSharing::Shamir.new(3)
  # for 3 out of 3 secret sharing.
  #
  # Raises ArgumentError unless 2 <= k <= n <= 255 (the x coordinate of a
  # share is serialized as two hex digits by Share#to_s).
  def initialize(n, k = n)
    raise ArgumentError, 'k must be smaller or equal than n' if k > n
    raise ArgumentError, 'k must be greater or equal to two' if k < 2
    raise ArgumentError, 'n must be smaller than 256' if n > 255

    @n = n
    @k = k
    @secret = nil
    @shares = []
    @received_shares = []
  end

  # Check whether the secret is set.
  def secret_set?
    !@secret.nil?
  end

  # Create a random secret of a certain bitlength (at most
  # MAX_SECRET_BITLENGTH), compute the n shares for it (stored in the
  # 'shares' attribute) and return the secret, which is also stored in the
  # 'secret' attribute.
  #
  # Raises RuntimeError if a secret is already set or if bitlength exceeds
  # MAX_SECRET_BITLENGTH.
  def create_random_secret(bitlength = DEFAULT_SECRET_BITLENGTH)
    raise 'secret already set' if secret_set?
    # Larger secrets would yield shares whose string form cannot be parsed
    # back (see MAX_SECRET_BITLENGTH).
    raise "max bitlength is #{MAX_SECRET_BITLENGTH}" if bitlength > MAX_SECRET_BITLENGTH

    @secret = get_random_number(bitlength)
    @secret_bitlength = bitlength
    create_shares
    @secret
  end

  # The secret in a password representation (Base64-encoded, without line
  # breaks). Raises RuntimeError if no secret has been set yet.
  #
  # NOTE: the hex digits are packed with 'h*' (low nibble first). This is
  # kept as-is because changing it would change every derived password.
  def secret_password
    raise 'Secret not (yet) set.' unless secret_set?

    Base64.encode64([@secret.to_s(16)].pack('h*')).delete("\n")
  end

  # Add a secret share to the object. Accepts either a
  # SecretSharing::Shamir::Share instance or a string representing one.
  # Returns true if enough (k) shares have been added to recover the secret,
  # which is then available in the 'secret' attribute; false otherwise.
  #
  # Raises ArgumentError for any other argument type, and RuntimeError if
  # the share was already added or k shares have already been received.
  def <<(share)
    # convert from string if needed
    share =
      case share
      when SecretSharing::Shamir::Share then share
      when String then SecretSharing::Shamir::Share.from_string(share)
      else
        raise ArgumentError, 'SecretSharing::Shamir::Share or String needed'
      end

    raise 'share has already been added' if @received_shares.include?(share)
    if @received_shares.length == @k
      raise 'we already have enough shares, no need to add more'
    end

    @received_shares << share
    return false unless @received_shares.length == @k

    recover_secret
    true
  end

  # Computes the modulus used for a given prime bitlength. The search starts
  # at 2^bitlength + 1 and walks the odd numbers upwards, testing each one
  # with prime_fasttest? and 20 rounds. That choice was intended for
  # compatibility with `openssl prime`, which is used in the
  # OpenXPKI::Crypto::Secret::Split library.
  #
  # NOTE(review): the candidate is incremented by 2 *after* the test that
  # succeeds, so the value returned is (smallest prime > 2^bitlength) + 2.
  # That value is generally NOT prime: for bitlength 5, 37 is found and
  # 39 = 3 * 13 is returned. Share.from_string recomputes the modulus
  # through this method, so every share already created depends on this
  # exact value, and changing it in place would make existing shares decode
  # to a wrong secret. A real fix needs a new share FORMAT_VERSION. With a
  # composite modulus, mod_inverse in the Lagrange interpolation can fail
  # when a difference of x coordinates shares a factor with it.
  # TODO: confirm impact and plan the format bump.
  def self.smallest_prime_of_bitlength(bitlength)
    candidate = OpenSSL::BN.new((2**bitlength + 1).to_s)
    prime_found = false
    until prime_found
      prime_found = candidate.prime_fasttest?(20)
      candidate += 2
    end
    candidate
  end

  private

  # Creates a random number from ceil(bitlength / 8) random bytes and keeps
  # only its lowest `bitlength` bits. If highest_bit_one is true, it then
  # also sets bit number `bitlength` (0-based). Note that this makes the
  # result bitlength + 1 bits long, i.e. it lies in
  # [2^bitlength, 2^(bitlength + 1)).
  def get_random_number(bitlength, highest_bit_one = true)
    byte_length = (bitlength / 8.0).ceil
    hex = OpenSSL::Random.random_bytes(byte_length).unpack('H*').first
    number = OpenSSL::BN.new(hex, 16)
    begin
      number.mask_bits!(bitlength)
    rescue OpenSSL::BNError
      # Deliberately ignored. OpenSSL reports an error when the number is
      # already too short to need masking, so it is already below
      # 2^bitlength (assumed to be the only failure mode here).
    end
    number.set_bit!(bitlength) if highest_bit_one
    number
  end

  # Creates the shares. It builds a polynomial of degree k-1 whose constant
  # term is the secret and whose other coefficients are random, then
  # evaluates it at x = 1..n.
  def create_shares
    @coefficients = [@secret]

    # Move to the next multiple of 4 (one nibble) strictly above the secret
    # bitlength: 255 becomes 256, 256 becomes 260. The prime bitlength is
    # then a multiple of 4 plus one, which is the form Share#to_s stores in
    # nibbles. Kept as-is because the resulting prime is part of the share
    # format.
    next_nibble_bitlength = @secret_bitlength + (4 - (@secret_bitlength % 4))
    prime_bitlength = next_nibble_bitlength + 1
    @prime = self.class.smallest_prime_of_bitlength(prime_bitlength)

    # compute random coefficients
    # NOTE(review): coefficients come from get_random_number, i.e. from
    # [2^b, 2^(b+1)), rather than uniformly from Z/pZ. The
    # information-theoretic guarantee of Shamir's scheme therefore does not
    # strictly hold. TODO: confirm whether this is acceptable.
    (1...k).each do |i|
      @coefficients[i] = get_random_number(@secret_bitlength)
    end

    (1..n).each do |x|
      @shares[x - 1] = construct_share(x, prime_bitlength)
    end
  end

  # Construct a share by evaluating the polynomial at x and creating
  # a SecretSharing::Shamir::Share object.
  def construct_share(x, bitlength)
    p_x = evaluate_polynomial_at(x)
    SecretSharing::Shamir::Share.new(x, p_x, @prime, bitlength)
  end

  # Evaluate the polynomial at x, reducing modulo @prime.
  def evaluate_polynomial_at(x)
    bn_x = OpenSSL::BN.new(x.to_s) # loop-invariant, convert once
    result = OpenSSL::BN.new('0')
    @coefficients.each_with_index do |coeff, i|
      result += coeff * bn_x**i
      result %= @prime
    end
    result
  end

  # Recover the secret by doing Lagrange interpolation at 0 over the
  # received shares, modulo the shares' prime.
  def recover_secret
    @secret = OpenSSL::BN.new('0')
    @received_shares.each do |share|
      l_x = l(share.x, @received_shares)
      summand = share.y * l_x
      summand %= share.prime
      @secret += summand
      @secret %= share.prime
    end
    @secret
  end

  # Part of the Lagrange interpolation.
  # For the share with x coordinate x_j = x this is l_j(0), i.e.
  #   \prod_{i \neq j} \frac{-x_i}{x_j - x_i}   (mod p)
  # Assumes all shares carry the same prime (the first share's is used).
  # For more information compare Wikipedia:
  # https://en.wikipedia.org/wiki/Lagrange_polynomial
  def l(x, shares)
    prime = shares[0].prime
    factors = shares.reject { |s| s.x == x }.map do |s|
      minus_xi = OpenSSL::BN.new((-s.x).to_s)
      one_over_xj_minus_xi = OpenSSL::BN.new((x - s.x).to_s).mod_inverse(prime)
      minus_xi.mod_mul(one_over_xj_minus_xi, prime)
    end
    factors.inject { |product, factor| product.mod_mul(factor, prime) }
  end
end
# A SecretSharing::Shamir::Share object represents a share in the
# Shamir secret sharing scheme. The share consists of a point (x,y) on
# a polynomial over Z/pZ, where p is the modulus returned by
# SecretSharing::Shamir.smallest_prime_of_bitlength for the share's prime
# bitlength.
class SecretSharing::Shamir::Share
  attr_reader :x, :y, :prime_bitlength, :prime

  FORMAT_VERSION = '0'

  # Create a new share with the given point, prime and prime bitlength.
  def initialize(x, y, prime, prime_bitlength)
    @x = x
    @y = y
    @prime = prime
    @prime_bitlength = prime_bitlength
  end

  # Create a new share from a string format representation. For
  # a discussion of the format, see the to_s() method.
  #
  # Surrounding whitespace is ignored and hex digits may be given in either
  # case, because shares are meant to be re-typed from printed copies.
  #
  # Raises RuntimeError in any of these cases:
  # * the format version is unknown;
  # * the string is too short to hold all fields;
  # * it contains non-hex characters;
  # * the checksum over the y coordinate does not match.
  def self.from_string(string)
    string = string.strip.upcase
    version = string[0, 1]
    raise "invalid share format version #{version}." if version != FORMAT_VERSION

    # version (1) + x (2) + at least one y digit (1) + checksum (4) +
    # prime bitlength (2)
    raise "invalid share length #{string.length}." if string.length < 10

    # The checksum only covers the y coordinate. String#hex silently stops
    # at the first invalid character, so reject non-hex input explicitly;
    # otherwise a typo in x or the bitlength would go unnoticed.
    raise 'invalid characters in share.' unless string =~ /\A[0-9A-F]+\z/

    x = string[1, 2].hex
    p_x_str = string[3...-6]
    checksum = string[-6, 4]
    prime_bitlength = 4 * string[-2, 2].hex + 1

    computed_checksum = Digest::SHA1.hexdigest(p_x_str)[0, 4].upcase
    if checksum != computed_checksum
      raise "invalid checksum. expected #{checksum}, got #{computed_checksum}"
    end

    prime = SecretSharing::Shamir.smallest_prime_of_bitlength(prime_bitlength)
    new(x, OpenSSL::BN.new(p_x_str, 16), prime, prime_bitlength)
  end

  # A string representation of the share, that can for example be
  # distributed in printed form.
  # The string is an uppercase hexadecimal string of the following
  # format: ABBC*DDDDEE, where
  # * A (the first nibble) is the version number of the format, currently
  #   fixed to 0.
  # * B (the next byte, two hex characters) is the x coordinate of the
  #   point on the polynomial.
  # * C (the next variable length of hex characters) is the y coordinate of
  #   the point on the polynomial.
  # * D (the next two bytes, four hex characters) are the first four hex
  #   characters of the SHA1 hash of the string representing the y
  #   coordinate. They are used as a checksum to guard against typos.
  # * E (the last byte, two hex characters) is the bitlength of the prime
  #   number in nibbles, i.e. (prime_bitlength - 1) / 4.
  # Both B and E must fit into two hex characters (values up to 255),
  # otherwise from_string cannot read the string back.
  def to_s
    # bitlength in nibbles to save space
    prime_nibbles = (@prime_bitlength - 1) / 4
    p_x = ('%x' % @y).upcase
    FORMAT_VERSION + ('%02X' % @x) \
      + p_x \
      + Digest::SHA1.hexdigest(p_x)[0, 4].upcase \
      + ('%02X' % prime_nibbles)
  end

  # Shares are equal if their string representation is the same.
  def ==(share)
    share.to_s == to_s
  end
end