| 1 | """RSA module |
|---|
| 2 | |
|---|
| 3 | This is a module based on the works by Sybren Stuvel, Marloes de Boer and Ivo Tamboer, |
|---|
| 4 | tlslite, helped by Nils Toedtmann, done wrong by Joerg Baach ;-) |
|---|
| 5 | |
|---|
| 6 | This file still needs serious audit before you can trust it for anything productive |
|---|
| 7 | |
|---|
| 8 | """ |
|---|
| 9 | |
|---|
# NOTE: With a positive modulus, Python's % always returns a non-negative
# result (the sign of the result follows the divisor), so reductions modulo n
# need no abs() compensation.
|---|
| 12 | |
|---|
| 13 | import math |
|---|
| 14 | import sys |
|---|
| 15 | import random # For picking semi-random numbers |
|---|
| 16 | import types |
|---|
| 17 | from hashlib import sha256 |
|---|
| 18 | |
|---|
# os provides urandom(), the operating system's cryptographically secure RNG.
|---|
| 20 | import os |
|---|
def getRandomBytes(howMany):
    """Return ``howMany`` cryptographically secure random bytes.

    The bytes come from ``os.urandom``, the operating system's CSPRNG.
    Hashing output of the Mersenne Twister (``random.getrandbits``) does not
    make it unpredictable, so it must not be used for key material.

    :param howMany: number of bytes to return (non-negative int).
    :returns: a mutable ``array.array('B')`` of length ``howMany``; callers
        such as getRandomNumber() modify its first element in place.
    """
    # array.array('B', ...) accepts the str returned by os.urandom on
    # Python 2 as well as the bytes returned on Python 3.
    return array.array('B', os.urandom(howMany))
|---|
| 36 | |
|---|
| 37 | |
|---|
| 38 | |
|---|
def log(x, base=10):
    """Return the logarithm of ``x`` in ``base`` (default 10), as ln(x) / ln(base)."""
    natural_log_x = math.log(x)
    natural_log_base = math.log(base)
    return natural_log_x / natural_log_base
|---|
| 41 | |
|---|
def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b`` (Euclid's algorithm).

    The larger argument is used as the initial dividend; ``gcd(0, 0)`` is 0.
    """
    if a >= b:
        dividend, divisor = a, b
    else:
        dividend, divisor = b, a
    while divisor != 0:
        dividend, divisor = divisor, dividend % divisor
    return dividend
|---|
| 47 | |
|---|
def bytes2int(bytes):
    """Converts a list of bytes or a string to an integer (big-endian).

    Each element may be an int in range(256) or a one-character string,
    which is converted with ord(). Iterating a Python 3 ``bytes`` object
    already yields ints.

    >>> (128*256 + 64)*256 + 15
    8405007
    >>> l = [128, 64, 15]
    >>> bytes2int(l)
    8405007
    """
    integer = 0
    for byte in bytes:
        # isinstance(..., str) replaces the Python-2-only types.StringType.
        if isinstance(byte, str):
            byte = ord(byte)
        integer = integer * 256 + byte
    return integer
|---|
| 69 | |
|---|
def int2bytes(number):
    """Converts a non-negative integer to a big-endian string of bytes.

    Every character of the result has an ordinal in range(256); zero
    encodes as the empty string.

    >>> bytes2int(int2bytes(123456789))
    123456789

    :raises TypeError: if ``number`` is not exactly an int (or Python 2 long).
    :raises ValueError: if ``number`` is negative.
    """
    # (int, long) on Python 2, (int,) on Python 3. An exact type match keeps
    # the original rejection of bool and other int subclasses.
    integer_types = (type(0), type(1 << 64))
    if type(number) not in integer_types:
        raise TypeError("You must pass a long or an int")
    if number < 0:
        raise ValueError("Cannot convert a negative number to bytes")

    chars = []
    while number > 0:
        chars.append(chr(number & 0xFF))
        # Shift rather than "/=": "/" is true division on Python 3.
        number >>= 8
    chars.reverse()
    return "".join(chars)
|---|
| 87 | |
|---|
| 88 | |
|---|
| 89 | |
|---|
def ceil(x):
    """Returns int(math.ceil(x)): the smallest integer not less than ``x``."""
    smallest_upper = math.ceil(x)
    return int(smallest_upper)
|---|
| 94 | |
|---|
| 95 | |
|---|
def rsa_operation(message, ekey, n):
    """Return ``message ** ekey mod n``: the raw, unpadded RSA primitive.

    encrypt/decrypt/sign/verify all delegate here. Which operation it is
    depends only on whether ``ekey`` is the public exponent e or the
    private exponent d.

    :param message: an int/long, or a string whose characters are taken as
        big-endian bytes via bytes2int (on Python 3, ``bytes`` too).
    :param ekey: the exponent.
    :param n: the modulus.
    :returns: the result as an integer in range(n).
    :raises TypeError: if ``message`` is of any other type.
    :raises ValueError: if the integer message is outside range(n). Such a
        value would be silently reduced modulo n and could never be
        recovered.
    """
    integer_types = (type(0), type(1 << 64))  # (int, long) on Py2; (int,) on Py3
    string_types = (str, type(b''))  # str on Py2; str and bytes on Py3

    if type(message) in string_types:
        message = bytes2int(message)
    if type(message) not in integer_types:
        raise TypeError("You must pass a long or an int, not %s" % type(message))
    if not 0 <= message < n:
        raise ValueError("message must be in range(n) (0 <= message < n)")
    return pow(message, ekey, n)
|---|
| 111 | |
|---|
| 112 | |
|---|
def blinding_operation(m, secret, n):
    """Return ``m * secret`` reduced modulo ``n``.

    blind() and unblind() both delegate here; each is a modular
    multiplication by the factor the caller supplies.
    """
    product = m * secret
    return product % n
|---|
| 115 | |
|---|
| 116 | |
|---|
def encrypt(message, key):
    """Encrypt ``message`` with the public key ``key``.

    Reads key['e'] (exponent) and key['n'] (modulus) and delegates to
    rsa_operation(), which defines the accepted message types.
    """
    public_exponent = key['e']
    modulus = key['n']
    return rsa_operation(message, public_exponent, modulus)
|---|
| 120 | |
|---|
def sign(message, key):
    """Sign ``message`` with the private key ``key``.

    Reads key['d'] (private exponent) and key['n'] (modulus) and delegates
    to rsa_operation().
    """
    private_exponent = key['d']
    modulus = key['n']
    return rsa_operation(message, private_exponent, modulus)
|---|
| 124 | |
|---|
def decrypt(cypher, key):
    """Decrypt ``cypher`` with the private key ``key``.

    Reads key['d'] (private exponent) and key['n'] (modulus) and delegates
    to rsa_operation().
    """
    private_exponent = key['d']
    modulus = key['n']
    return rsa_operation(cypher, private_exponent, modulus)
|---|
| 128 | |
|---|
def verify(cypher, key):
    """Apply the public key ``key`` to the signature ``cypher``.

    Reads key['e'] (public exponent) and key['n'] (modulus) and delegates
    to rsa_operation(). The caller compares the result with the expected
    message.
    """
    public_exponent = key['e']
    modulus = key['n']
    return rsa_operation(cypher, public_exponent, modulus)
|---|
| 132 | |
|---|
def blind(message, secret, key):
    """Blind ``message`` by multiplying it with ``secret`` modulo key['n']."""
    modulus = key['n']
    return blinding_operation(message, secret, modulus)
|---|
| 135 | |
|---|
def unblind(message, secret, key):
    """Unblind ``message`` by multiplying it with ``secret`` modulo key['n'].

    The operation is the same as blind(). The caller passes the unblinding
    factor here rather than the blinding factor.
    """
    modulus = key['n']
    return blinding_operation(message, secret, modulus)
|---|
| 138 | |
|---|
| 139 | |
|---|
| 140 | #import math |
|---|
| 141 | |
|---|
def bits(integer):
    """Return the number of bits needed to represent ``abs(integer)``.

    ``bits(0)`` is 0. Negative values are measured by magnitude. Shifting a
    negative number directly never terminates, because ``-1 >> 1 == -1``.
    """
    remaining = abs(integer)
    count = 0
    while remaining:
        remaining >>= 1
        count += 1
    return count
|---|
| 148 | |
|---|
| 149 | |
|---|
def invMod(a, b):
    """Return the inverse of ``a`` modulo ``b``, or 0 if none exists.

    Uses the extended Euclidean algorithm. For b > 1 and gcd(a, b) == 1, the
    result x lies in range(b) and satisfies (a * x) % b == 1. If gcd(a, b)
    != 1, the function returns 0.
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        # Floor division: "/" yields a float on Python 3 and corrupts the
        # result. "//" has existed since Python 2.2, Jython included.
        q = d // c
        c, d = d - (q * c), c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0
|---|
| 162 | |
|---|
def getUnblinder(n):
    """Return a random r from getRandomNumber(0, n) with gcd(r, n) == 1.

    Candidates are redrawn until one is coprime to n, i.e. invertible
    modulo n.
    """
    candidate = getRandomNumber(0, n)
    while gcd(candidate, n) != 1:
        candidate = getRandomNumber(0, n)
    return candidate
|---|
| 169 | |
|---|
| 170 | |
|---|
def generate(bits):
    """Generate an RSA key pair with a ``bits``-bit modulus.

    :param bits: modulus size in bits. Each prime gets about half of it,
        and getRandomPrime requires at least 10 bits per prime.
    :returns: ``(public, private)``, where public is ``{'e': e, 'n': n}``
        and private is ``{'d': d, 'n': n}``.
    """
    # Integer split: "bits / 2" is a float on Python 3. getRandomPrime sets
    # the two top bits of each prime, so n = p * q has exactly
    # p_bits + q_bits == bits bits, and odd sizes no longer lose a bit.
    p_bits = bits // 2
    q_bits = bits - p_bits
    p = getRandomPrime(p_bits, False)
    q = getRandomPrime(q_bits, False)
    while q == p:
        # p == q would make n a perfect square, trivially factored.
        q = getRandomPrime(q_bits, False)
    t = (p - 1) * (q - 1)
    n = p * q

    # 65537 is the customary public exponent. FIPS 186-4 requires
    # 2**16 < e < 2**256. Small exponents such as 17 are dangerous with
    # this module's unpadded RSA.
    e = 65537
    while gcd(e, t) != 1:
        e += 2

    d = invMod(e, t)
    return ({'e': e, 'n': n}, {'d': d, 'n': n})


# Backwards-compatible alias for generate().
gen_pubpriv_keys = generate
|---|
| 189 | |
|---|
def getRandomPrime(bits, display=False):
    """Return a random probable prime with exactly ``bits`` bits.

    The two most significant bits are set, so the product of two such primes
    has exactly the sum of their bit lengths. Candidates stay congruent to
    29 mod 30 (30 = 2*3*5), so multiples of 2, 3 and 5 are never tested.

    :param bits: bit length; must be at least 10.
    :param display: write a progress marker to stdout for each candidate.
    :raises AssertionError: if ``bits < 10`` (kept for compatibility).
    """
    if bits < 10:
        raise AssertionError()
    # 3 * 2**(bits-2) == 1.5 * 2**(bits-1) is the smallest bits-bit number
    # with its two top bits set. It is computed in exact integer arithmetic;
    # the old "(2L ** (bits-1)) * 3/2" was Python-2-only syntax.
    low = 3 * 2 ** (bits - 2)
    high = 2 ** bits - 30
    p = getRandomNumber(low, high)
    p += 29 - (p % 30)
    while True:
        if display:
            sys.stdout.write(". ")
        p += 30
        if p >= high:
            # Ran past the top of the range: restart from a fresh random point.
            p = getRandomNumber(low, high)
            p += 29 - (p % 30)
        if isPrime(p, display=display):
            return p
|---|
| 210 | |
|---|
# Pre-calculate the 168 primes below 1000 for trial division in isPrime().
def makeSieve(n):
    """Return the sorted list of all primes below ``n`` (sieve of Eratosthenes).

    >>> makeSieve(30)
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    """
    # list(): on Python 3, range() is an immutable lazy sequence.
    candidates = list(range(n))
    # The bound must include floor(sqrt(n)). Excluding it left squares of
    # primes, such as 9 (n=10) or 961 = 31**2 (n=1000), marked as prime.
    for count in range(2, int(math.sqrt(n)) + 1):
        if candidates[count] == 0:
            continue
        # Smaller multiples were already crossed out by smaller primes.
        for multiple in range(count * count, n, count):
            candidates[multiple] = 0
    return [x for x in candidates[2:] if x]

sieve = makeSieve(1000)
|---|
| 225 | |
|---|
def isPrime(n, iterations=5, display=False):
    """Return True if ``n`` is (probably) prime.

    Trial division by the primes in the module-level ``sieve`` settles small
    cases. Survivors get ``iterations`` rounds of Rabin-Miller: the first
    round uses base 2, later rounds use random bases in [2, n-2].

    :param n: integer to test.
    :param iterations: number of Rabin-Miller rounds.
    :param display: write a '*' marker to stdout before Rabin-Miller.
    :returns: False for n < 2 or when a witness of compositeness is found,
        otherwise True.
    """
    # 0, 1 and negative numbers are not prime. Without this check, the
    # "x >= n" shortcut below would report them as prime.
    if n < 2:
        return False
    # Trial division with the sieve.
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False
    # Passed trial division; Rabin-Miller per Ferguson & Schneier.
    if display:
        sys.stdout.write("* ")
    # Write n - 1 as s * 2**t with s odd.
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1
    for count in range(iterations):
        # Base 2 first (cheap, per HAC), then a fresh random base every round.
        # The old code skipped the redraw whenever a**s == 1, so it could
        # repeat base 2 for all rounds.
        a = 2 if count == 0 else getRandomNumber(2, n - 1)
        v = pow(a, s, n)
        if v == 1 or v == n - 1:
            continue
        for _ in range(t - 1):
            v = pow(v, 2, n)
            if v == n - 1:
                break
        else:
            # No square in the chain reached -1: a witnesses compositeness.
            return False
    return True
|---|
| 252 | |
|---|
def lcm(a, b):
    """Return the least common multiple of ``a`` and ``b`` (0 if either is 0)."""
    if a == 0 or b == 0:
        # gcd(0, 0) is 0; don't divide by it.
        return 0
    # Floor division keeps the result an exact integer; "/" is true division
    # on Python 3. "//" has existed since Python 2.2, Jython included.
    return (a * b) // gcd(a, b)
|---|
| 257 | |
|---|
def getRandomNumber(low, high):
    """Return a random integer in [low, high) by rejection sampling.

    Draws numBytes(high) bytes from getRandomBytes, trims the leading byte to
    numBits(high) bits in total, and retries until the value is in range.
    The result is uniform provided getRandomBytes returns uniform bytes.

    :raises AssertionError: if ``low >= high``.
    """
    if low >= high:
        raise AssertionError()
    total_bits = numBits(high)
    byte_count = numBytes(high)
    spare_bits = total_bits % 8
    while True:
        candidate_bytes = getRandomBytes(byte_count)
        if spare_bits:
            # Clear the unused high-order bits of the leading byte.
            candidate_bytes[0] = candidate_bytes[0] % (1 << spare_bits)
        candidate = bytesToNumber(candidate_bytes)
        if low <= candidate < high:
            return candidate
|---|
| 271 | |
|---|
| 272 | |
|---|
| 273 | |
|---|
def numberToBytes(n):
    """Return ``n`` as a big-endian byte array of numBytes(n) bytes (empty for 0)."""
    size = numBytes(n)
    out = createByteArrayZeros(size)
    index = size - 1
    while index >= 0:
        out[index] = int(n % 256)
        n >>= 8
        index -= 1
    return out
|---|
| 281 | |
|---|
def stringToBytes(s):
    """Return the bytes of ``s`` as an ``array.array('B')``.

    Accepts a byte string (Python 2 ``str``, Python 3 ``bytes``). On
    Python 3, a text ``str`` is encoded as Latin-1, which maps code points
    0-255 to the same byte values.
    """
    result = array.array('B')
    if hasattr(result, 'frombytes'):
        # Python 3: array.fromstring() was deprecated and removed in 3.9.
        if not isinstance(s, (bytes, bytearray)):
            s = s.encode('latin-1')
        result.frombytes(s)
    else:
        # Python 2 arrays only provide fromstring().
        result.fromstring(s)
    return result
|---|
| 286 | |
|---|
| 287 | |
|---|
| 288 | import array |
|---|
def createByteArraySequence(seq):
    """Return an unsigned-byte array (typecode 'B') holding the values in ``seq``."""
    unsigned_byte = 'B'
    return array.array(unsigned_byte, seq)
|---|
def createByteArrayZeros(howMany):
    """Return an unsigned-byte array of ``howMany`` zeros (empty when howMany <= 0)."""
    single_zero = array.array('B', [0])
    return single_zero * howMany
|---|
| 293 | |
|---|
def bytesToNumber(bytes):
    """Interpret a sequence of byte values (ints 0-255) as a big-endian integer.

    An empty sequence yields 0.
    """
    # Plain 0 instead of the Python-2-only "0L" literal; Python ints are
    # arbitrary precision on both versions.
    total = 0
    for byte in bytes:
        total = total * 256 + byte
    return total
|---|
| 302 | |
|---|
| 303 | import math |
|---|
# Bit length of each possible leading hex digit; used by numBits().
_HEX_DIGIT_BITS = {'0': 0, '1': 1, '2': 2, '3': 2,
                   '4': 3, '5': 3, '6': 3, '7': 3,
                   '8': 4, '9': 4, 'a': 4, 'b': 4,
                   'c': 4, 'd': 4, 'e': 4, 'f': 4}


def numBits(n):
    """Return the number of bits in the binary representation of ``n``.

    ``numBits(0)`` is 0. Working from the hex representation keeps the
    result exact for arbitrarily large integers.

    :raises ValueError: if ``n`` is negative.
    """
    if n == 0:
        return 0
    if n < 0:
        raise ValueError("numBits requires a non-negative integer")
    hex_digits = "%x" % n
    # Every hex digit after the first contributes exactly 4 bits.
    return (len(hex_digits) - 1) * 4 + _HEX_DIGIT_BITS[hex_digits[0]]
|---|
| 315 | |
|---|
def numBytes(n):
    """Return how many whole bytes are needed to hold ``n`` (0 when ``n == 0``)."""
    if n == 0:
        return 0
    bit_count = numBits(n)
    # Round the bit count up to a whole number of bytes.
    return (bit_count + 7) // 8
|---|
| 321 | |
|---|
| 322 | |
|---|
| 323 | |
|---|
# Hard-coded key pair with e = 17. Its private exponent is published in this
# source file, so it offers no secrecy; use generate() for real keys.
_DUMMY_N = 119854184191709851267115469806947480444279686024088476366095898577590388788441098128656550419021335940507973237749080463551087748856480956399611369963199842509280458397835875930007780781743559815353663167178392850613614395643351736086803768428705593212269700447923498009295529617805433375506688910349264456281
(dummypub, dummypriv) = (
    {'e': 17, 'n': _DUMMY_N},
    {'d': 112803938062785742369049853935950569829910292728553860109266728073026248271473974709323812159078904414595739517881487495106906116570805606023163642318305713478430901878058050836432331457217407052608225335053649488684207226225731712680194298917913785665189224809172156144674031166473829983092121536846612402225,
     'n': _DUMMY_N},
)
|---|
| 325 | |
|---|
| 326 | |
|---|
# Demo and rough timing when run as a script (no doctests are run here).
if __name__ == "__main__":
    import time

    (pub, priv) = gen_pubpriv_keys(1024)
    print('=' * 40)

    # --- Blind signature round trip ---
    times = []
    t = time.time()
    message = bytes2int('c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2')
    unblinder = getUnblinder(pub['n'])
    # blinder = (r^-1)^e mod n, so unblind(sign(blind(m)), r) == m^d mod n.
    blinder = pow(invMod(unblinder, pub['n']), pub['e'], pub['n'])
    times.append(time.time() - t)
    t = time.time()

    blinded = blind(message, blinder, pub)
    times.append(time.time() - t)
    t = time.time()

    signedblind = sign(blinded, priv)
    times.append(time.time() - t)
    t = time.time()

    unblinded = unblind(signedblind, unblinder, pub)
    times.append(time.time() - t)
    t = time.time()

    print('verified %s' % (message == verify(unblinded, pub)))
    times.append(time.time() - t)
    # Total time excluding the signer's private-key step (times[2]).
    print(sum(times) - times[2])

    print('=' * 40)
    # --- Plain encrypt / decrypt / sign, no blinding ---
    t = time.time()
    message = 'serial ' * 5
    cypher = encrypt(message, pub)
    decrypt(cypher, priv)
    sign(message, priv)
    print(time.time() - t)
|---|
| 398 | |
|---|
| 399 | |
|---|
| 400 | |
|---|
| 401 | |
|---|
# Public API for "from <module> import *". The blinding helpers are included
# because a blind-signature client needs them, as the __main__ demo shows.
__all__ = ["gen_pubpriv_keys", "generate", "encrypt", "decrypt", "sign",
           "verify", "blind", "unblind", "getUnblinder", "invMod"]
|---|