root / trunk / sandbox / jhb / oc2 / rsa.py

Revision 322, 11.2 kB (checked in by ocjhb, 3 years ago)

strengthening the random numbers; ordering coins for inspection; changing wording

  • Property svn:mime-type set to text/plain
  • Property svn:eol-style set to native
  • Property svn:executable set to *
Line 
1"""RSA module
2
3This is a module based on the works by Sybren Stuvel, Marloes de Boer and Ivo Tamboer,
4tlslite, helped by Nils Toedtmann, done wrong by Joerg Baach ;-)
5
6This file still needs serious audit before you can trust it for anything productive
7
8"""
9
# NOTE: Python's % operator returns a result with the sign of the divisor, so
# x % m is never negative when m > 0; no abs() compensation is used or needed.
12
13import math
14import sys
15import random    # For picking semi-random numbers
16import types
17from hashlib import sha256
18
19# Get os.urandom PRNG
20import os
def getRandomBytes(howMany):
    """Return *howMany* random bytes from the operating system's CSPRNG.

    The result is an ``array.array('B')`` of length *howMany* (the same
    container type stringToBytes() produces), filled from ``os.urandom``.

    The previous version drew bits from ``random.getrandbits`` -- the
    Mersenne Twister, which the ``random`` module documents as unsuitable
    for security purposes -- and hashed them with SHA-256. Hashing cannot
    add entropy to a predictable generator, and these bytes end up as RSA
    primes and blinding factors, so they must come from os.urandom.
    """
    import array  # the module-level `import array` sits further down the file
    return array.array('B', os.urandom(howMany))
36       
37 
38
def log(x, base = 10):
    """Return the logarithm of *x* to the given *base* (10 unless specified)."""
    numerator = math.log(x)
    denominator = math.log(base)
    return numerator / denominator
41
def gcd(a,b):
    """Return the greatest common divisor of *a* and *b* (Euclid's algorithm)."""
    if a >= b:
        larger, smaller = a, b
    else:
        larger, smaller = b, a
    while smaller:
        larger, smaller = smaller, larger % smaller
    return larger
47
def bytes2int(bytes):
    """Convert a big-endian sequence of bytes to a non-negative integer.

    *bytes* may be a list of ints (0..255), a byte string, or any other
    iterable of byte values. Elements that are one-character strings are
    converted with ord(); integer elements are used as they are. An empty
    sequence yields 0.

    >>> (128*256 + 64)*256 + 15
    8405007
    >>> l = [128, 64, 15]
    >>> bytes2int(l)
    8405007
    """
    integer = 0
    for byte in bytes:
        if isinstance(byte, str):
            byte = ord(byte)
        # Horner's rule: shift what we have one byte left, add the new byte
        integer = integer * 256 + byte
    return integer
69
def int2bytes(number):
    """Convert a non-negative integer to a big-endian string of bytes.

    0 (and any negative number) yields the empty string.

    Raises TypeError unless *number* is exactly an int or a long (bool is
    rejected too).

    >>> bytes2int(int2bytes(123456789))
    123456789
    """
    try:
        integer_types = (int, long)
    except NameError:  # Python 3 has a single int type
        integer_types = (int,)
    if type(number) not in integer_types:
        raise TypeError("You must pass a long or an int")

    chars = []
    while number > 0:
        chars.append(chr(number & 0xFF))
        # Shift rather than `/= 256`, so this stays integer arithmetic even
        # under true division.
        number >>= 8
    # Collected least-significant byte first; emit most-significant first.
    chars.reverse()
    return "".join(chars)
87
88
89
def ceil(x):
    """Round *x* up to the nearest whole number and return it as an int."""
    rounded_up = math.ceil(x)
    return int(rounded_up)
94
95
def rsa_operation(message, ekey, n):
    """Return ``message ** ekey % n``, the core RSA primitive.

    encrypt() and verify() call this with the public exponent; decrypt() and
    sign() call it with the private exponent. *message* may be an int, a
    long, or a byte string, which is first converted with bytes2int().

    Raises TypeError for any other message type. Raises ValueError unless
    0 <= message < n: RSA works on residues modulo n, so a larger message
    would be silently reduced mod n and could never be recovered.
    """
    try:
        integer_types = (int, long)
    except NameError:  # Python 3 has a single int type
        integer_types = (int,)

    if type(message) is str:
        message = bytes2int(message)

    if type(message) not in integer_types:
        raise TypeError("You must pass a long or an int, not %s" % type(message))
    if not 0 <= message < n:
        raise ValueError("message must satisfy 0 <= message < n")

    return pow(message, ekey, n)
111
112
def blinding_operation(m,secret,n):
    """Multiply *m* by *secret* modulo *n* (shared by blind() and unblind())."""
    product = m * secret
    return product % n
115
116
def encrypt(message, key):
    """Encrypt *message* with the public key *key* (uses key['e'] and key['n'])."""
    public_exponent = key['e']
    modulus = key['n']
    return rsa_operation(message, public_exponent, modulus)
120
def sign(message, key):
    """Sign *message* with the private key *key* (uses key['d'] and key['n'])."""
    private_exponent = key['d']
    modulus = key['n']
    return rsa_operation(message, private_exponent, modulus)
124
def decrypt(cypher, key):
    """Decrypt *cypher* with the private key *key* (uses key['d'] and key['n'])."""
    private_exponent = key['d']
    modulus = key['n']
    return rsa_operation(cypher, private_exponent, modulus)
128
def verify(cypher, key):
    """Apply the public key *key* to a signature *cypher*.

    Returns the recovered message (key['e'] / key['n'] exponentiation); the
    caller compares it with the expected message.
    """
    public_exponent = key['e']
    modulus = key['n']
    return rsa_operation(cypher, public_exponent, modulus)
132
def blind(message,secret,key):
    """Blind *message* by multiplying it with *secret* modulo key['n']."""
    modulus = key['n']
    return blinding_operation(message, secret, modulus)
135
def unblind(message,secret,key):
    """Multiply a blind signature *message* by *secret* modulo key['n'].

    With *secret* set to the unblinder r, this removes the blinding factor.
    """
    modulus = key['n']
    return blinding_operation(message, secret, modulus)
138
139
140#import math
141
def bits(integer):
    """Return the number of bits needed to represent the non-negative *integer*.

    bits(0) == 0. Raises ValueError for negative input: the old shift loop
    never terminated there, because ``-1 >> 1 == -1``.
    """
    if integer < 0:
        raise ValueError("bits() requires a non-negative integer")
    result = 0
    while integer:
        integer >>= 1
        result += 1
    return result
148
149
def invMod(a, b):
    """Return the multiplicative inverse of *a* modulo *b*, or 0 if none exists.

    Extended Euclidean algorithm: (c, d) run through the remainder sequence
    while (uc, ud) track the matching coefficients of *a*, keeping
    uc*a == c and ud*a == d (mod b). If the last non-zero remainder d is 1,
    then ud*a == 1 (mod b). A result of 0 means gcd(a, b) != 1.
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        # Floor division keeps q an integer. The old `/` relied on classic
        # division and silently turns into float division under true division.
        q = d // c
        c, d = d - (q * c), c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0
162
def getUnblinder(n):
    """Pick a random blinding factor r in [0, n) with gcd(r, n) == 1.

    Draws candidates from getRandomNumber(0, n) until one is coprime to n.
    """
    candidate = getRandomNumber(0, n)
    while gcd(candidate, n) != 1:
        candidate = getRandomNumber(0, n)
    return candidate
169
170
def generate(bits):  #needed
    """Generate an RSA key pair from two random primes of bits // 2 bits each.

    Returns ``(public, private)``, where public is ``{'e': e, 'n': n}`` and
    private is ``{'d': d, 'n': n}``. e is the smallest odd number >= 17 that
    is coprime to t = (p-1)*(q-1), and d = e**-1 mod t.
    """
    # Floor division: `bits/2` would become a float under true division.
    half = bits // 2
    p = getRandomPrime(half, False)
    q = getRandomPrime(half, False)
    # If p == q, n would be a perfect square (factored by a square root) and
    # (p-1)*(q-1) would not be its totient. Redraw q in that unlikely case.
    while q == p:
        q = getRandomPrime(half, False)
    t = (p - 1) * (q - 1)
    n = p * q

    # d = e**-1 mod t exists only when gcd(e, t) == 1.
    e = 17
    while gcd(e, t) != 1:
        e += 2

    d = invMod(e, t)
    return ({'e': e, 'n': n}, {'d': d, 'n': n})

gen_pubpriv_keys = generate
189
def getRandomPrime(bits, display=False):
    """Return a random (probable) prime p with 1.5 * 2**(bits-1) <= p < 2**bits.

    Starting at 1.5 * 2**(bits-1) ensures the two most significant bits are
    set. When two such primes are used as p and q for RSA, n therefore has
    its MSB set.

    Candidates are kept at 29 mod 30. Since 30 = lcm(2, 3, 5), this skips
    everything divisible by 2, 3 or 5. The search steps upward by 30 and
    restarts from a fresh random point when it runs past the top of the
    range. If *display* is true, a '.' is written to stdout for each
    candidate, and *display* is also passed on to isPrime().

    Raises AssertionError if bits < 10.
    """
    if bits < 10:
        raise AssertionError()
    # `//` keeps this integral under true division; 2**(bits-1) * 3 is even,
    # so the result is exact.
    low = (2 ** (bits - 1)) * 3 // 2
    high = 2 ** bits - 30

    def freshCandidate():
        # Random start in [low, high), bumped up to the next value == 29 (mod 30)
        candidate = getRandomNumber(low, high)
        return candidate + 29 - (candidate % 30)

    p = freshCandidate()
    while 1:
        if display:
            sys.stdout.write(".")
        p += 30
        if p >= high:
            p = freshCandidate()
        if isPrime(p, display=display):
            return p
210
# Pre-calculate the 168 primes below 1000 for trial division in isPrime().
def makeSieve(n):
    """Return the sorted list of all primes p with 2 <= p < n (Sieve of Eratosthenes).

    >>> makeSieve(30)
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    """
    table = list(range(n))
    # Strike out multiples of every prime up to and INCLUDING floor(sqrt(n)).
    # The old bound range(2, int(sqrt(n))) stopped one short, so the square of
    # the largest such prime survived (25 for n = 30, 961 = 31**2 for n = 1000).
    for count in range(2, int(math.sqrt(n)) + 1):
        if table[count] == 0:
            continue
        # Smaller multiples of count were already struck by smaller primes.
        x = count * count
        while x < n:
            table[x] = 0
            x += count
    return [value for value in table[2:] if value]

sieve = makeSieve(1000)
225
def isPrime(n, iterations=5, display=False):
    """Return True if *n* is (probably) prime, False otherwise.

    Uses trial division by the precomputed small primes in ``sieve``, then
    *iterations* rounds of Rabin-Miller (per Ferguson & Schneier). The first
    round uses base 2 (a cheap first check, per HAC); every later round
    draws a fresh random base in [2, n-2]. Numbers below 2 are not prime.
    If *display* is true, a '*' is written to stdout when Rabin-Miller starts.
    """
    if n < 2:
        return False

    # Trial division with the small-prime sieve
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False

    # Passed trial division, proceed to Rabin-Miller
    if display:
        sys.stdout.write("*")
    # Write n - 1 = s * 2**t with s odd
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1

    for count in range(iterations):
        # Pick the base at the top of every round. Previously a round with
        # v == 1 `continue`d past the re-draw, so later rounds reused base 2.
        if count == 0:
            a = 2
        else:
            # a = n-1 always passes and proves nothing, so draw from [2, n-2]
            a = getRandomNumber(2, n - 1)
        v = pow(a, s, n)
        if v == 1 or v == n - 1:
            continue
        # Square up to t-1 more times looking for n-1; if it never appears,
        # a is a witness that n is composite.
        for _ in range(t - 1):
            v = pow(v, 2, n)
            if v == n - 1:
                break
        else:
            return False
    return True
252
def lcm(a, b):
    """Return the least common multiple of *a* and *b*; 0 if either is 0."""
    if a == 0 or b == 0:
        # gcd(0, 0) == 0, so the formula below would divide by zero.
        return 0
    # Floor division: exact here, and stays integral under true division
    # (the old `/` relied on classic division).
    return (a * b) // gcd(a, b)
257
def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Rejection sampling: draw numBytes(high) random bytes, mask the leading
    byte down to numBits(high) significant bits, and retry until the value
    falls inside the range. Raises AssertionError if low >= high.
    """
    if not low < high:
        raise AssertionError()
    bitCount = numBits(high)
    byteCount = numBytes(high)
    topBits = bitCount % 8
    while True:
        candidate = getRandomBytes(byteCount)
        if topBits:
            candidate[0] %= 1 << topBits
        value = bytesToNumber(candidate)
        if low <= value < high:
            return value
271
272
273
def numberToBytes(n):
    """Return the big-endian unsigned-byte array encoding of integer *n*.

    The array has numBytes(n) elements (none for n == 0).
    """
    size = numBytes(n)
    out = createByteArrayZeros(size)
    for index in reversed(range(size)):
        out[index] = int(n % 256)
        n >>= 8
    return out
281
def stringToBytes(s):
    """Return an unsigned-byte array holding the raw bytes of the string *s*."""
    byteArray = createByteArrayZeros(0)
    byteArray.fromstring(s)
    return byteArray
286
287
288import array
def createByteArraySequence(seq):
    """Return an unsigned-byte ('B') array initialised from *seq*."""
    typecode = 'B'
    byteArray = array.array(typecode, seq)
    return byteArray
def createByteArrayZeros(howMany):
    """Return an unsigned-byte ('B') array of *howMany* zero bytes."""
    return array.array('B', [0]) * howMany
293   
def bytesToNumber(bytes):
    """Interpret a sequence of integer byte values as a big-endian number.

    Returns 0 for an empty sequence. Accepts any iterable of ints, such as
    the ``array.array('B')`` objects used throughout this module, lists, or a
    bytearray.
    """
    # Horner's rule over the bytes in order; equivalent to the old
    # reverse-index loop with a running multiplier, without the
    # Python-2-only `0L` / `1L` literals.
    total = 0
    for byte in bytes:
        total = total * 256 + byte
    return total
302
303import math
def numBits(n):
    """Return the number of bits in the binary representation of *n*.

    numBits(0) == 0. The count comes from the hex representation, so it
    works for arbitrarily large longs without floating point. Every hex
    digit after the first contributes 4 bits; the leading digit contributes
    1-4 bits.

    Raises ValueError for negative *n*. Previously that case was an
    accidental KeyError on the '-' sign of the hex string.
    """
    if n < 0:
        raise ValueError("numBits() requires a non-negative integer, got %r" % (n,))
    if n == 0:
        return 0
    s = "%x" % n
    leadingDigitBits = {'0': 0, '1': 1, '2': 2, '3': 2,
                        '4': 3, '5': 3, '6': 3, '7': 3,
                        '8': 4, '9': 4, 'a': 4, 'b': 4,
                        'c': 4, 'd': 4, 'e': 4, 'f': 4}
    return (len(s) - 1) * 4 + leadingDigitBits[s[0]]
315
def numBytes(n):
    """Return how many bytes are needed to hold the integer *n* (0 for n == 0)."""
    if n == 0:
        return 0
    bitLength = numBits(n)
    # Round the bit count up to whole bytes
    return (bitLength + 7) // 8
321
322
323
# Hard-coded key pair with public exponent e = 17. The private exponent d is
# published right here in the source, so this pair provides no secrecy: use it
# only for tests or demos, never for anything real. Nothing else in this
# module reads these names.
(dummypub,dummypriv) = ({'e': 17, 'n': 119854184191709851267115469806947480444279686024088476366095898577590388788441098128656550419021335940507973237749080463551087748856480956399611369963199842509280458397835875930007780781743559815353663167178392850613614395643351736086803768428705593212269700447923498009295529617805433375506688910349264456281L}, {'d': 112803938062785742369049853935950569829910292728553860109266728073026248271473974709323812159078904414595739517881487495106906116570805606023163642318305713478430901878058050836432331457217407052608225335053649488684207226225731712680194298917913785665189224809172156144674031166473829983092121536846612402225L, 'n': 119854184191709851267115469806947480444279686024088476366095898577590388788441098128656550419021335940507973237749080463551087748856480956399611369963199842509280458397835875930007780781743559815353663167178392850613614395643351736086803768428705593212269700447923498009295529617805433375506688910349264456281L})
325
326
# When run as a script: time a 1024-bit blind signature (no doctests are run).
if __name__ == "__main__":
    import time

    (pub, priv) = gen_pubpriv_keys(1024)
    print('=' * 40)

    # --- Blind signature, timing each step separately ---------------------
    times = []
    t = time.time()
    # NOTE(review): bytes2int() encodes the 64 ASCII characters of this hex
    # string, not the 256-bit value it spells; int(s, 16) would give that.
    message = bytes2int('c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2')
    unblinder = getUnblinder(pub['n'])
    # blinder = unblinder**-e mod n: signing message*blinder gives
    # message**d * unblinder**-1, and multiplying by the unblinder
    # leaves message**d.
    blinder = pow(invMod(unblinder, pub['n']), pub['e'], pub['n'])
    times.append(time.time() - t)   # [0] choose blinding factor
    t = time.time()

    blinded = blind(message, blinder, pub)
    times.append(time.time() - t)   # [1] blind
    t = time.time()

    signedblind = sign(blinded, priv)
    times.append(time.time() - t)   # [2] sign
    t = time.time()

    unblinded = unblind(signedblind, unblinder, pub)
    times.append(time.time() - t)   # [3] unblind
    t = time.time()

    print('verified %s' % (message == verify(unblinded, pub)))
    times.append(time.time() - t)   # [4] verify
    # Total time of every step except the signing itself (times[2])
    print(sum(times) - times[2])

    # Set to True to also exercise encrypt/decrypt/sign/verify verbosely.
    run_full_demo = False
    if run_full_demo:
        # --- Full round trip, then blind signature, with output -----------
        t = time.time()
        message = bytes2int('serial ' * 5)
        print('cleartext %s' % message)
        cypher = encrypt(message, pub)
        print('cyphertext: %s' % cypher)
        print('decrypted %s' % decrypt(cypher, priv))
        signed = sign(message, priv)
        print('signed %s' % signed)
        print('verified %s' % (message == verify(signed, pub)))
        unblinder = getUnblinder(pub['n'])
        blinder = pow(invMod(unblinder, pub['n']), pub['e'], pub['n'])
        blinded = blind(message, blinder, pub)
        print('blinded %s' % blinded)
        signedblind = sign(blinded, priv)
        print('signedblind %s' % signedblind)
        unblinded = unblind(signedblind, unblinder, pub)
        print('unblinded %s' % unblinded)
        print('verified %s' % (message == verify(unblinded, pub)))
        print(time.time() - t)

        print('=' * 40)
        # --- Same primitives without blinding, timing only ----------------
        t = time.time()
        message = 'serial ' * 5
        cypher = encrypt(message, pub)
        decrypt(cypher, priv)
        sign(message, priv)
        print(time.time() - t)
398
399     
400
401       
# Names exported by a star import of this module. The blinding helpers
# (blind, unblind, getUnblinder) are importable by name but are not listed.
__all__ = ["gen_pubpriv_keys", "encrypt", "decrypt", "sign", "verify"]
Note: See TracBrowser for help on using the browser.