Ursa Minor
4 minutes to read
We are given the following source code and an instance to connect to:
# Polymero
# Imports
from Crypto.Util.number import isPrime, getPrime, inverse
import hashlib, time, os
# Local import
FLAG = os.environ.get('FLAG').encode()
class URSA:
    """Upgraded RSA (faster and with cheap key cycling).

    Key material lives in two dicts:
      public  -- {'n': modulus, 'e': public exponent}
      private -- {'p': prime, 'q': prime, 'f': (p - 1) * (q - 1), 'd': private exponent}
    """

    def __init__(self, pbit, lbit):
        """Generate both primes and derive the key dicts; e starts at 0x10001."""
        p, q = self.prime_gen(pbit, lbit)
        totient = (p - 1) * (q - 1)
        self.public = {'n': p * q, 'e': 0x10001}
        self.private = {
            'p': p,
            'q': q,
            'f': totient,
            'd': inverse(self.public['e'], totient),
        }

    def prime_gen(self, pbit, lbit):
        """Return (P, Q), two primes of the form 2 * (product of lbit-bit primes) + 1.

        Each product is built from pbit // lbit values drawn with getPrime(lbit).
        Q is searched for first, then P.
        """
        # NOTE(review): the nesting here was reconstructed from a listing that
        # had lost its indentation -- confirm against the original challenge.
        factor_count = pbit // lbit

        def smooth_prime():
            # Keep drawing factor lists until one passes the duplicate check
            # and 2 * product + 1 is prime.
            while True:
                factors = [getPrime(lbit) for _ in range(factor_count)]
                if len(factors) - len(set(factors)) > 1:
                    continue
                product = 1
                for factor in factors:
                    product *= factor
                candidate = 2 * product + 1
                if isPrime(candidate):
                    return candidate

        Q = smooth_prime()
        P = smooth_prime()
        return P, Q

    def update_key(self):
        """Replace d and e without touching the primes.

        d is XORed with the SHA-512 digest of str(d) + str(time.time()),
        reduced modulo f, and e is set to inverse(d, f).
        """
        old_d = self.private['d']
        seed = (str(old_d) + str(time.time())).encode()
        mask = int.from_bytes(hashlib.sha512(seed).digest(), 'big')
        self.private['d'] = old_d ^ mask
        self.private['d'] = self.private['d'] % self.private['f']
        self.public['e'] = inverse(self.private['d'], self.private['f'])

    def encrypt(self, m_int):
        """Return a list of ciphertexts for m_int.

        The current value is raised to e modulo n and appended, then the value
        is floor-divided by n; this repeats until the value is 0, so an input
        of 0 yields an empty list.
        """
        n, e = self.public['n'], self.public['e']
        ciphertexts = []
        while m_int:
            ciphertexts.append(pow(m_int, e, n))
            m_int = m_int // n
        return ciphertexts

    def decrypt(self, c_int):
        """Return a list of plaintexts for c_int, chunked like encrypt but using d."""
        n, d = self.public['n'], self.private['d']
        plaintexts = []
        while c_int:
            plaintexts.append(pow(c_int, d, n))
            c_int = c_int // n
        return plaintexts
# Challenge setup
# NOTE(review): this listing has lost all indentation and appears to be missing
# lines; as shown it is not runnable Python. The banner text below has no
# visible opening `print("""`, so it is presumably the tail of one or more
# triple-quoted print calls -- confirm against the original challenge file.
| ~ Welcome to URSA decryption services
| Press enter to start key generation...""")
| Please hold on while we generate your primes...
# Build the service key by calling URSA with pbit=256 and lbit=12.
oracle = URSA(256, 12)
print("| ~ You are connected to an URSA-256-12 service, public key ::")
# Only the SHA-256 of str(n) is printed as the key id; n itself is not printed.
print("| id = {}".format(hashlib.sha256(str(oracle.public['n']).encode()).hexdigest()))
print("| e = {}".format(oracle.public['e']))
print("|\n| ~ Here is a free flag sample, enjoy ::")
# Print every element returned by encrypting FLAG read as a big-endian integer.
for i in oracle.encrypt(int.from_bytes(FLAG, 'big')):
print("| {}".format(i))
# Menu template; the `{}` placeholder is filled with 4 - CYCLE in the loop.
# NOTE(review): the closing triple quote of this string is not visible here.
MENU = """|
| ~ Menu (key updated after {} requests)::
| [E]ncrypt
| [D]ecrypt
| [U]pdate key
| [Q]uit
# Server loop
# NOTE(review): CYCLE is read below but no initial assignment is visible, and
# the `except` near the end has no visible matching `try:` -- presumably both
# were lost in extraction; confirm.
while True:
if CYCLE % 4:
print(MENU.format(4 - CYCLE))
choice = input("| > ")
# NOTE(review): as shown, this overwrites the user's input unconditionally;
# presumably it belongs to a missing `else:` of the `if CYCLE % 4:` check -- confirm.
choice = 'u'
if choice.lower() == 'e':
msg = int(input("|\n| > (int) "))
print("|\n| ~ Encryption ::")
for i in oracle.encrypt(msg):
print("| {}".format(i))
elif choice.lower() == 'd':
cip = int(input("|\n| > (int) "))
print("|\n| ~ Decryption ::")
for i in oracle.decrypt(cip):
print("| {}".format(i))
elif choice.lower() == 'u':
# NOTE(review): no call to oracle.update_key() is visible in this branch --
# presumably lost; confirm.
print("|\n| ~ Key updated succesfully ::")
print("| id = {}".format(hashlib.sha256(str(oracle.public['n']).encode()).hexdigest()))
print("| e = {}".format(oracle.public['e']))
elif choice.lower() == 'q':
print("|\n| ~ Closing services...\n|")
# NOTE(review): the next print presumably sits under a missing `else:` -- confirm.
print("|\n| ~ ERROR - Unknown command")
CYCLE += 1
except KeyboardInterrupt:
print("\n| ~ Closing services...\n|")
# NOTE(review): presumably the body of a second, missing `except` clause -- confirm.
print("|\n| ~ Please do NOT abuse our services.\n|")
The server generates two primes $p$ and $q$ such that $p - 1$ and $q - 1$ are smooth (that is, they factor entirely into small primes). However, we are not given $n = p \cdot q$ itself, only the SHA256 hash of $n$. Moreover, $e = 65537$, as usual. The flag is encrypted with RSA using these parameters.
After that, we are given the opportunity to encrypt messages, decrypt ciphertexts or update the key.
I noticed that there is something weird in the `encrypt` method:
def encrypt(self, m_int):
    """Return a list of ciphertexts for m_int.

    The current value is raised to e modulo n and appended, then the value
    is floor-divided by n; this repeats until the value is 0.
    """
    n, e = self.public['n'], self.public['e']
    ciphertexts = []
    while m_int:
        ciphertexts.append(pow(m_int, e, n))
        m_int = m_int // n
    return ciphertexts
First of all, the message must be entered as a decimal number. Furthermore, if the number is greater than or equal to $n$, the function will perform more than one iteration of the `while` loop. Hence, we have a way to obtain $n$ by doing Binary Search:
We enter a number $x$. If $x \geq n$, then the server will send more than one ciphertext; and if $x < n$, the server will reply with a single ciphertext. Therefore, we can set two limits (for example, $2^{256}$ and $2^{512}$) and halve the interval with Binary Search until we get $n$:
# Binary search for the modulus n using the encrypt oracle.
# The oracle returns more than one ciphertext exactly when the submitted
# value is >= n, so the loop keeps the invariant a < n <= b.
# `encrypt(r, x)` is presumably a helper that sends x to the service over
# connection r and returns the list of ciphertexts -- defined elsewhere.
a, b = 2 ** 256, 2 ** 512
while a + 1 != b:
    test_n = (a + b) // 2
    if len(encrypt(r, test_n)) > 1:
        # test_n >= n: the modulus is in the lower half.
        b = test_n
    else:
        # test_n < n: the modulus is in the upper half.
        a = test_n

# a and b are now consecutive; n is odd (a product of two odd primes), so
# picking the odd one of the pair yields n.
n = a if a % 2 else b
print(n, a, b)
# Cross-check against the SHA-256 id of n printed by the server.
assert hashlib.sha256(str(n).encode()).hexdigest() == n_id_hex
To verify it, we can assert that the SHA256 hash of $n$ matches the one sent by the server at the beginning.
Once we have $n$, and because $p - 1$ and $q - 1$ are smooth, we can use Pollard’s p - 1 algorithm to factor it and then decrypt the flag as in most RSA challenges:
# Factor n with Pollard's p - 1 (helper defined elsewhere), then perform a
# textbook RSA decryption of the flag ciphertext.
factors = pollard_p_1(n)
p, q = factors
assert p * q == n

# Private exponent: the inverse of e modulo (p - 1) * (q - 1).
phi_n = (p - 1) * (q - 1)
d = pow(e, -1, phi_n)

# flag_enc is presumably the flag ciphertext printed by the service at startup.
m = pow(flag_enc, d, n)
So we can get the flag:
$ python3 solve.py blackhat2-09afaf950bafc7bc0c7c3d69fcaeb7df-0.chals.bh.ctf.sa
[+] Opening connection to blackhat2-09afaf950bafc7bc0c7c3d69fcaeb7df-0.chals.bh.ctf.sa on port 443: Done
[*] Closed connection to blackhat2-09afaf950bafc7bc0c7c3d69fcaeb7df-0.chals.bh.ctf.sa port 443
The full script can be found here: solve.py