PicoCTF: SRA Challenge

Guillaume Bethouart

Here's the write-up for the SRA challenge of PicoCTF 2023, which ran from March 14 to March 28. This challenge is part of the Cryptography category.

The challenge

The challenge description is the following:

I just recently learnt about the SRA public key cryptosystem... or wait, was it supposed to be RSA? Hmmm, I should probably check... Connect to the program on our server: nc saturn.picoctf.net 60330 Download the program: chal.py

No hint is given for this challenge.

Code analysis

The server runs chal.py. To make it easier to understand, we can first rename the variables (the original names are kept as comments):

from Crypto.Util.number import getPrime, inverse, bytes_to_long
from string import ascii_letters, digits
from random import choice
import time
time.clock = time.time

plain = "".join(choice(ascii_letters + digits) for _ in range(16))  # pride
p = getPrime(128)  # gluttony
q = getPrime(128)  # greed

# Use fixed values for demonstration purposes
p, q = 321315087699876844450938908960149703981, 236046326807786585102625365485104504817

n = p * q  # lust
e = 65537  # sloth
phi = (p - 1) * (q - 1)
d = inverse(e, phi)  # envy
cipher = pow(bytes_to_long(plain.encode()), e, n)  # anger

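This is a standard RSA key-pair generation with a 256-bit modulus (the product of two 128-bit primes). The server then sends us cipher and d.

To obtain the flag, we must recover plain from cipher and d. But decryption (plain = cipher^d mod n) also requires the modulus n, which we do not have.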

Compute n from d and e

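$(e, d)$ is an RSA key pair, so that $e \cdot d \equiv 1 \pmod{\varphi(n)}$ with $\varphi(n) = (p - 1)(q - 1)$.

$n = p \cdot q$ is the public modulus, which is unknown here.

From this key equation we derive, for some integer $k$:

$$e \cdot d - 1 = k \cdot \varphi(n) = k \cdot (p - 1) \cdot (q - 1)$$

So if we factor $e \cdot d - 1$, we get the factors of $p - 1$ and $q - 1$, multiplied by an unknown $k$.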
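The relation can be checked on a toy key. The sketch below uses the classic textbook parameters $p = 61$, $q = 53$, $e = 17$ (chosen for illustration only, they are not the challenge values). It also shows that $k < e$, which always holds because $d < \varphi(n)$.

```python
# Toy RSA parameters, for illustration only (not the challenge values)
p, q, e = 61, 53, 17
n = p * q
phi = (p - 1) * (q - 1)
d = pow(e, -1, phi)  # modular inverse of e modulo phi (Python 3.8+)

# e*d - 1 is a multiple of phi(n): e*d - 1 = k * phi(n)
k, remainder = divmod(e * d - 1, phi)
print(n, phi, d, k, remainder)  # 3233 3120 2753 15 0
```

Here $e \cdot d - 1 = 46800 = 15 \cdot 3120$, so $k = 15$, which is indeed smaller than $e = 17$.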

For further explanation, you can read this nice crypto.stackexchange post.

Factor decomposition

Let us start with the factorization of $e \cdot d - 1$. To do that, you can use the primefac module (which is a bit slow), Sage, or this very fast website. I used the latter, then wrote a small script to parse the website's output.

dem1 = d * e - 1
dem1

1845238994787093868829071101511950464175922457296917469750101517162126249781286720

factors_string = '2^6 × 3^2 × 5 × 587 × 24329 × 129631 × 407807 × 80787 180953 × 1241 605438 156271 × 84761 673627 946307 × 99817 949863 851199'
factors = []
for factor in factors_string.split(' × '):
   factor = factor.replace(' ', '').strip()
   if '^' in factor:
       factor = factor.split('^')
       factors += [int(factor[0])] * int(factor[1])
   else:
       factors.append(int(factor))

factors = sorted(factors, reverse=True)
print(factors)

[99817949863851199, 84761673627946307, 1241605438156271, 80787180953, 407807, 129631, 24329, 587, 5, 3, 3, 2, 2, 2, 2, 2, 2]
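Before going further, it is worth checking that the parsed factors multiply back to $e \cdot d - 1$, which catches copy-paste or parsing mistakes. A minimal sketch, where the helper name `check_factors` is hypothetical, shown on the toy value $46800 = 2^4 \cdot 3^2 \cdot 5^2 \cdot 13$:

```python
import math


def check_factors(value, factors):
    """Return True if the factors (all > 1) multiply back to `value`."""
    return all(f > 1 for f in factors) and math.prod(factors) == value


toy_factors = [13, 5, 5, 3, 3, 2, 2, 2, 2]
print(check_factors(46800, toy_factors))      # True
print(check_factors(46800, toy_factors[1:]))  # False: the factor 13 is missing
```

In the challenge, `check_factors(dem1, factors)` is expected to return True.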

Because $p$ and $q$ are 128-bit primes, the bit size of $n$ (and of $\varphi(n)$) is at most 256 bits.
Comparing the size of $e \cdot d - 1$ with the size of $n$ then gives us the maximum bit size of $k$.

We have a set of prime numbers, and from this set we need to find all subsets whose product corresponds to $k$, $p - 1$ or $q - 1$, with a suitable bit size.

More formally, with $e \cdot d - 1 = \prod_i f_i$, we search all $x$ built from subsets $S$ of the factors $f_i$, such that $x = \prod_{i \in S} f_i$ and $\log_2 x$ is close to the expected bit size $b$.

Since $\log_2 x = \sum_{i \in S} \log_2 f_i$, thinking in bit sizes turns this product search into an instance of the subset sum problem.
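For $p - 1$ and $q - 1$, the expected bit size follows from the size of the primes. Since $p$ is an odd $b$-bit prime (here $b = 128$), we have $2^{b-1} < p < 2^b$, hence:

$$
2^{b-1} \le p - 1 < 2^{b}
\quad\Longrightarrow\quad
b - 1 \le \log_2 (p - 1) < b
$$

and likewise for $q - 1$: the products we are looking for have a bit size just below $b$.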

Subset sum problem

The subset sum problem is a well-known (NP-complete) problem.

The following code comes from a Stack Overflow thread on solving the subset sum problem, adapted for our purpose.

def get_subsets_sum(data, target, epsilon):
    """Find subsets of `data` whose sum lies in (`target` - `epsilon`, `target`]."""
    subsets = []

    # Maps a remaining difference (target - partial sum) to the values of that partial sum
    differences = {}
    for value in data:
        prospects = []
        for diff in differences:
            # value fills the remaining difference to within epsilon: a solution is found
            if (diff >= value) and abs(value - diff) < epsilon:
                new_subset = [value] + differences[diff]
                new_subset.sort()
                if new_subset not in subsets:
                    subsets.append(new_subset)
            # value still fits below the target: keep this partial sum as a prospect
            if value - diff < 0.:
                prospects.append((value, diff))

        # update the differences record
        for prospect in prospects:
            new_diff = target - sum(differences[prospect[1]]) - prospect[0]
            differences[new_diff] = differences[prospect[1]] + [prospect[0]]
        differences[target - value] = [value]
    return subsets

We can then use it to find subsets of factors whose product has a given size, by working on the bit sizes (base-2 logarithms) of the factors:

import math

def get_subsets_product(data, target, epsilon):
    """Find products of values in `data` close to `target`, working on bit sizes.

    Returns the products whose bit size lies in (log2(target) - epsilon, log2(target)].
    """
    data = sorted(data, reverse=True)
    candidates = []
    bit_sizes = [math.log2(value) for value in data]
    target_bit_size = math.log2(target)
    bit_sizes_subsets = get_subsets_sum(bit_sizes, target_bit_size, epsilon)
    for subset in bit_sizes_subsets:
        # Map each bit size back to its factor (equal bit sizes mean equal factors)
        values_subset = [data[bit_sizes.index(bit_size)] for bit_size in subset]
        candidates.append(math.prod(values_subset))
    return candidates
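As an alternative to the subset sum heuristic, note that $e \cdot d - 1$ has few distinct divisors here: with prime exponents 6 and 2 plus nine exponents equal to 1, there are $(6+1)(2+1) \cdot 2^9 = 10752$ of them, so they can also simply all be enumerated. A minimal sketch, assuming the factors are given as a flat list with repetitions like `factors` above (the helper name `all_divisors` is hypothetical):

```python
from collections import Counter


def all_divisors(factors):
    """Enumerate every divisor of prod(factors), given a list of primes with repetitions."""
    divisors = [1]
    for prime, exponent in Counter(factors).items():
        # Multiply every divisor found so far by prime^0, prime^1, ..., prime^exponent
        divisors = [divisor * prime**i for divisor in divisors for i in range(exponent + 1)]
    return sorted(divisors)


toy = all_divisors([13, 5, 5, 3, 3, 2, 2, 2, 2])  # divisors of 46800
print(len(toy), toy[:6])  # 90 [1, 2, 3, 4, 5, 6]
```

Each divisor $x$ can then be tested directly, for example by keeping those for which $x + 1$ is a 128-bit prime.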

Find k candidates

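We can first reduce the list of possible factors of $k$ according to its maximum size. Since $n$ is unknown, the code below uses the ciphertext, which is smaller than $n$, as a stand-in for its size.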

k_size_max = int(math.log2(dem1) - math.log2(cipher))
k_size_max

17

k_factors = list(filter(lambda factor: math.log2(factor) < k_size_max, factors))
k_factors

[129631, 24329, 587, 5, 3, 3, 2, 2, 2, 2, 2, 2]

k_candidates = get_subsets_product(k_factors, 2**17, 1)
k_candidates

[121645, 72987, 97316, 105660, 70440, 84528, 93920, 112704]
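Another way to narrow down $k$ relies on the standard bound $k < e$: since $d < \varphi(n)$, we have $k = (e \cdot d - 1) / \varphi(n) < e$. With $e = 65537$, one can simply try every $k$ below $e$, keeping those that divide $e \cdot d - 1$ and give a $\varphi(n)$ of plausible size (two $b$-bit primes give a $\varphi(n)$ of $2b - 1$ or $2b$ bits) that is divisible by 4 (since $p - 1$ and $q - 1$ are both even). A sketch, with a hypothetical helper name, shown on the toy key $p = 61$, $q = 53$, $e = 17$, $d = 2753$:

```python
def k_candidates_from_bound(ed_minus_1, e, prime_bits):
    """Divisors k < e of e*d - 1 whose quotient could be phi(n) = (p-1)(q-1)."""
    candidates = []
    for k in range(1, e):
        phi, remainder = divmod(ed_minus_1, k)
        if remainder:
            continue
        # p-1 and q-1 each have prime_bits bits, so phi has 2b-1 or 2b bits
        if phi.bit_length() not in (2 * prime_bits - 1, 2 * prime_bits):
            continue
        # p-1 and q-1 are both even, so 4 divides phi
        if phi % 4 == 0:
            candidates.append(k)
    return candidates


# Toy key with 6-bit primes: e*d - 1 = 17 * 2753 - 1 = 46800
print(k_candidates_from_bound(46800, 17, 6))  # [12, 13, 15]
```

For the challenge, the call would be `k_candidates_from_bound(dem1, 65537, 128)`; it needs at most 65536 big-integer divisions and no factorization.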

In the following steps, we need to remove the factors of each $k$ candidate from the list of factors of $e \cdot d - 1$. Let's write a small function for this:

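def remove_factors(factors, k_value, k_factors):
    """Return `factors` without the prime factors of `k_value` (taken from `k_factors`)."""
    factors = sorted(factors, reverse=True)
    k_factors = sorted(k_factors, reverse=True)
    result = factors.copy()
    for k_factor in k_factors:
        # Repeated primes appear several times in k_factors, so prime powers are removed too
        if k_value % k_factor == 0:
            k_value = k_value // k_factor
            result.remove(k_factor)
    return result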

Find p - 1 (or q - 1)

For each $k$ candidate, we can now look for products of the remaining factors (i.e. the original list of factors without the factors of the $k$ candidate) that could correspond to $p - 1$ or $q - 1$, with the following constraints:

  • bit size <= 128
  • this product + 1 is a prime number (should be p or q)

import primefac

porqm1_candidates = set()

for k_candidate in k_candidates:
   remaining_factors = remove_factors(factors, k_candidate, k_factors)
   porqm1_candidate_set = get_subsets_product(remaining_factors, 2**128, 5)
   for porqm1_candidate in porqm1_candidate_set:
       if primefac.isprime(porqm1_candidate + 1):
           porqm1_candidates.add(porqm1_candidate)
           
porqm1_candidates = list(porqm1_candidates)
porqm1_candidates

[110451235237789359826286740748914534620,
147528954254866615689140853428190315510,
11582308472669131473870323426904724338,
29505790850973323137828170685638063102,
34919787361941703640524468294085435040,
236046326807786585102625365485104504816,
18091216071735087475943777451592941846,
321315087699876844450938908960149703980,
187582124573757824249206295828661487320]
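If primefac is not installed, `primefac.isprime` can be replaced by a standard Miller-Rabin probabilistic primality test. A minimal pure-Python sketch (the name `is_probable_prime` is hypothetical); a composite number passes each random round with probability at most 1/4, so 40 rounds make a false positive negligible:

```python
import random

SMALL_PRIMES = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)


def is_probable_prime(n, rounds=40):
    """Miller-Rabin test: False means composite, True means prime with high probability."""
    if n < 2:
        return False
    # Quick trial division by small primes (also handles the small primes themselves)
    for prime in SMALL_PRIMES:
        if n % prime == 0:
            return n == prime
    # Write n - 1 = 2^s * d with d odd
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for _ in range(rounds):
        a = random.randrange(2, n - 1)
        x = pow(a, d, n)
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False  # a is a witness: n is composite
    return True
```

It can be used in place of `primefac.isprime(porqm1_candidate + 1)` in the loop above.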

Build the n candidates and find the correct one

Since the resulting set contains more than two candidates, we brute-force all pairs: for each pair, we build a candidate n, decrypt the ciphertext and print the results that decode as ASCII.

for idx, pm1 in enumerate(porqm1_candidates):
    for qm1 in porqm1_candidates[idx + 1:]:
        n_candidate = (pm1 + 1) * (qm1 + 1)
        possible_plain = pow(cipher, d, n_candidate)
        # Use the minimal byte length: a fixed length of 256 would prepend NUL bytes
        length = (possible_plain.bit_length() + 7) // 8
        possible_plain = possible_plain.to_bytes(length=length, byteorder='big')
        try:
            possible_plain = possible_plain.decode('ascii')
        except UnicodeDecodeError:
            continue
        print(possible_plain)

ke1CBDLfTKQtEDCw

Well done, there is only one solution!
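As a final sanity check, re-encrypting the recovered plaintext with $e$ and the candidate $n$ should give back the ciphertext sent by the server. A sketch with a hypothetical helper, where `int.from_bytes(..., "big")` plays the role of `bytes_to_long`, shown with the toy key $n = 61 \cdot 53 = 3233$, $e = 17$, for which `b"A"` (65) encrypts to 2790:

```python
def reencryption_matches(plain, e, n, cipher):
    """Check that encrypting `plain` (bytes) with (e, n) gives back `cipher`."""
    message = int.from_bytes(plain, "big")  # same conversion as bytes_to_long
    return pow(message, e, n) == cipher


print(reencryption_matches(b"A", 17, 3233, 2790))  # True
print(reencryption_matches(b"B", 17, 3233, 2790))  # False
```

In the challenge, this would be `reencryption_matches(possible_plain.encode(), 65537, n_candidate, cipher)`.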

Thanks to the PicoCTF team for providing this nice CTF. I learned a lot in many domains I'm not familiar with. I can't wait to read the write-ups for the challenges I haven't solved.