RSACBC
We are given the Python source code of the server that encrypts the flag:
```python
import os,json
from Crypto.Util.number import getPrime, bytes_to_long as b2l, long_to_bytes as l2b
from Crypto.Util.Padding import pad

class RSA_CBC:
    def __init__(self, p,q):
        self.n = p*q
        self.e = 0x10001
        self.p = p
        self.q = q
        self.d = pow(self.e, -1,(p-1)*(q-1))
        self.BLOCK_LENGTH = 32

    def xor(self,a,b):
        return bytes([x^y for x,y in zip(a,b)])

    def _encrypt(self, m):
        return l2b(pow(b2l(m), self.e, self.n))

    def bytes2blocks(self, m):
        return [m[i:i+self.BLOCK_LENGTH] for i in range(0, len(m), self.BLOCK_LENGTH)]

    def encrypt(self,m):
        blocks = self.bytes2blocks(pad(m,self.BLOCK_LENGTH))
        enc = [int.to_bytes(self.p,32,"big")]
        for i in range(len(blocks)):
            enc.append(self._encrypt(self.xor(enc[i],blocks[i])))
        return {"blocks":list(map(bytes.hex,enc[1:]))}

FLAG = os.getenv("FLAG", "HackOn{AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA}")
p = getPrime(256)
q = getPrime(256)
cipher = RSA_CBC(p,q)
print(f"n = {cipher.n}")

menu = """
1. Encrypt
2. ¿Flag?
3. Exit
"""

for _ in range(2):
    print(menu)
    option = int(input(">>> "))
    if option == 1:
        try:
            m = bytes.fromhex(input("Message: "))
        except:
            raise ValueError("Whaatt...")
        if len(m) > 192 or b"\x00"*32 in m:
            raise ValueError("Try harder...")
        print(json.dumps(cipher.encrypt(m)))
    elif option == 2:
        print(json.dumps(cipher.encrypt(FLAG.encode())))
        break
    else:
        exit()
```
Source code analysis
As the name of the challenge suggests, the server uses RSA in CBC mode to encrypt data:
```python
def encrypt(self,m):
    blocks = self.bytes2blocks(pad(m,self.BLOCK_LENGTH))
    enc = [int.to_bytes(self.p,32,"big")]
    for i in range(len(blocks)):
        enc.append(self._encrypt(self.xor(enc[i],blocks[i])))
    return {"blocks":list(map(bytes.hex,enc[1:]))}
```
That is, it pads the message, splits it into 32-byte blocks and encrypts each block using RSA. CBC mode XORs each plaintext block with the previous ciphertext block before encrypting it, and uses an initialization vector (IV) in place of the previous ciphertext block for the first plaintext block. An image is worth a thousand words, just take RSA as the “block cipher encryption”:
[Figure: CBC mode encryption diagram, with RSA as the block cipher encryption]
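In formulas, writing $m_i$ for the $i$-th padded plaintext block and $c_i$ for the $i$-th ciphertext block, with $c_0$ being the IV (here, the 32-byte encoding of $p$ itself), the scheme and its inverse can be sketched as follows, where $\oplus$ denotes the server's byte-wise `xor`:

$$
\begin{aligned}
c_0 &= p \\
c_i &= \left(c_{i-1} \oplus m_i\right)^e \bmod n, \qquad i \ge 1 \\
m_i &= \left(c_i^{\,d} \bmod n\right) \oplus c_{i-1}
\end{aligned}
$$

The last line holds because $c_{i-1} \oplus m_i$ is a 32-byte (256-bit) value, which is always smaller than the 512-bit modulus $n$, so RSA decryption returns it unchanged.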
We have two useful options, but only two opportunities to query the server (and option 2 breaks out of the loop, so it has to be the last query):
```python
for _ in range(2):
    print(menu)
    option = int(input(">>> "))
    if option == 1:
        try:
            m = bytes.fromhex(input("Message: "))
        except:
            raise ValueError("Whaatt...")
        if len(m) > 192 or b"\x00"*32 in m:
            raise ValueError("Try harder...")
        print(json.dumps(cipher.encrypt(m)))
    elif option == 2:
        print(json.dumps(cipher.encrypt(FLAG.encode())))
        break
    else:
        exit()
```
Basically, we can encrypt an almost arbitrary message (option 1), as long as it is at most 192 bytes long and does not contain 32 consecutive null bytes, or get the encrypted flag (option 2).
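Last but not least, the RSA parameters look secure enough: $p$ and $q$ are random 256-bit primes, so factoring the 512-bit modulus $n$ directly is not the intended path: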
```python
p = getPrime(256)
q = getPrime(256)
cipher = RSA_CBC(p,q)
print(f"n = {cipher.n}")
```
Solution
Since we want to decrypt RSA, we need to find a way to factor the public modulus $n$. The weak spot is the IV used by `encrypt`, which is the prime $p$ itself:
```python
enc = [int.to_bytes(self.p,32,"big")]
```
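So, what if we send a message that consists of only 32 null bytes? The first block to encrypt would be $p \oplus 0 = p$, so the result should be the value of $p^e \bmod n$, which is a multiple of $p$ (but not of $q$), and $\gcd(p^e \bmod n, n) = p$ would factor $n$. However, the server blocks exactly this message: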
```python
if len(m) > 192 or b"\x00"*32 in m:
    raise ValueError("Try harder...")
```
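A quick numeric check of why the all-zero message would work, and why it never reaches the encryption step. This is a sketch with toy numbers: the Mersenne primes $2^{31} - 1$ and $2^{61} - 1$ stand in for the real 256-bit primes, and `allowed` is a hypothetical helper that mirrors the server's filter.

```python
from math import gcd

E = 0x10001
p, q = 2**31 - 1, 2**61 - 1  # toy Mersenne primes instead of 256-bit ones
n = p * q


def allowed(m: bytes) -> bool:
    # Hypothetical helper mirroring the server's filter for option 1
    return not (len(m) > 192 or b"\x00" * 32 in m)


# 32 null bytes: the first block would be p XOR 0 = p, so the server would return p^e mod n
c = pow(p, E, n)
print(gcd(c, n) == p)                   # True: c is a multiple of p but not of q
print(allowed(b"\x00" * 32))            # False: the message is rejected
print(allowed(b"\x00" * 31 + b"\x01"))  # True: one non-zero byte is enough
```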
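So, let’s keep things simple. What happens if we send 31 null bytes and one `\x01`? This is the same as having $p \oplus 1 = p - 1$, because $p$ is odd: its least significant bit is 1, and the XOR just clears it. This way, we will be getting $c = (p - 1)^e \bmod n$ as the first ciphertext block, and $\gcd(c + 1, n) = p$. (Since the message is exactly 32 bytes long, `pad` appends a whole extra block of padding, but only the first ciphertext block matters here.)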
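The bit trick is easy to check at the byte level. The sketch below reuses the server's `xor` helper and, as an assumption for illustration, random odd 256-bit integers in place of $p$ (only the parity of $p$ matters here):

```python
import secrets


def xor(a: bytes, b: bytes) -> bytes:
    # Same as the server's helper: byte-wise XOR, truncated to the shorter input
    return bytes(x ^ y for x, y in zip(a, b))


payload = b"\x00" * 31 + b"\x01"  # first (and only) block of our option-1 message
print(payload.hex())              # the string we type at the "Message: " prompt

for _ in range(1000):
    # Hypothetical stand-in for p: any odd integer with exactly 256 bits
    p = secrets.randbits(256) | (1 << 255) | 1
    first_block = xor(p.to_bytes(32, "big"), payload)
    # XOR with 0x00...01 only clears the least significant bit of an odd number
    assert int.from_bytes(first_block, "big") == p - 1
```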
Proof
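Let’s forget about the reduction modulo $n$ for a moment and expand $(p - 1)^e$ with the binomial theorem:
$$(p - 1)^e = \sum_{i=0}^{e} \binom{e}{i} p^{i} (-1)^{e-i} = (-1)^e + p \sum_{i=1}^{e} \binom{e}{i} p^{i-1} (-1)^{e-i} = kp - 1$$
We have the above result because $e = 65537$ is odd, so $(-1)^e = -1$, and $k = \sum_{i=1}^{e} \binom{e}{i} p^{i-1} (-1)^{e-i}$ is an integer.
Now it is clear that $(p - 1)^e + 1 = kp$ is a multiple of $p$. However, the server returns $c = (p - 1)^e \bmod n$, not $(p - 1)^e$ itself, so we write
$$(p - 1)^e = sn + c, \qquad 0 \le c < n$$
This is just the division algorithm, with quotient $s$ and remainder $c$. Substituting $(p - 1)^e = kp - 1$ and $n = pq$:
$$c + 1 = kp - spq = p\,(k - sq)$$
And the above shows that $c + 1$ is a multiple of $p$, so $\gcd(c + 1, n) = p$ (unless $q$ also divides $c + 1$, which is overwhelmingly unlikely).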
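The identities in the proof can be verified numerically. This is a sketch with toy Mersenne primes $p = 2^{31} - 1$ and $q = 2^{61} - 1$ (assumptions for illustration), small enough that Python can compute $(p - 1)^e$ exactly, without any reduction:

```python
from math import gcd

E = 0x10001                      # odd public exponent, as on the server
p, q = 2**31 - 1, 2**61 - 1      # toy Mersenne primes instead of 256-bit ones
n = p * q

big = (p - 1) ** E               # (p - 1)^e with no modular reduction (about 2 million bits)
k, rem = divmod(big + 1, p)
assert rem == 0                  # (p - 1)^e = kp - 1 for an integer k

s, c = divmod(big, n)            # division algorithm: (p - 1)^e = sn + c, 0 <= c < n
assert c == pow(p - 1, E, n)     # c is the ciphertext block the server would return
assert c + 1 == p * (k - s * q)  # so c + 1 is a multiple of p
print(gcd(c + 1, n) == p)        # True
```

With the real 256-bit primes the unreduced power would have roughly $256 \cdot 65537 \approx 1.7 \cdot 10^7$ bits; the attack never needs it, since the gcd only involves the reduced value $c$.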
Observation
We could have also used a XOR with
Implementation
Now, the solution code should be easy to follow:
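```python
import json, sys
from math import gcd
from pwn import process, remote  # pwntools

def get_process():
    # Hypothetical stand-in for the helper in the full solve.py (assumed behavior)
    if len(sys.argv) == 3:  # python3 solve.py HOST PORT
        return remote(sys.argv[1], int(sys.argv[2]))
    return process(['python3', 'server.py'])  # assumed local server filename

io = get_process()
io.recvuntil(b'n = ')
n = int(io.recvline())
io.sendlineafter(b'>>> ', b'1')
# First plaintext block 0x00...01: the server encrypts p XOR 1 = p - 1
io.sendlineafter(b'Message: ', (b'\0' * 31 + b'\x01').hex().encode())
kp_minus_1 = int(json.loads(io.recvline().decode()).get('blocks', [])[0], 16)
p = gcd(kp_minus_1 + 1, n)  # c + 1 is a multiple of p
assert p.bit_length() == 256
io.info(f'{p = }')
q = n // p
d = pow(0x10001, -1, (p - 1) * (q - 1))
```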
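Once we have $d$, we are able to decrypt RSA. We only need to take into account that the server uses CBC mode, so we must XOR each decrypted block with the previous ciphertext block (starting with the IV, $p$) and only keep 32 bytes. This is 512-bit RSA, so ciphertext blocks are up to 64 bytes long, while plaintext blocks are only 32 bytes long; since the server's `xor` is built on `zip`, which stops at the shorter input, only the first 32 bytes of the previous ciphertext block take part in the XOR: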
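```python
from Crypto.Util.Padding import unpad
from pwn import xor  # assumed: pwntools xor() repeats the shorter input (cut='max')

io.sendlineafter(b'>>> ', b'2')
blocks = json.loads(io.recvline().decode()).get('blocks', [])
flag = []
prev_block = p.to_bytes(32, 'big')  # the IV is p
for b in blocks:
    m = pow(int(b, 16), d, n)  # first 32 bytes of prev_block XOR plaintext block
    # Keep 32 bytes: only the first 32 bytes of the previous ciphertext were XORed
    flag.append(xor(m.to_bytes(32, 'big'), prev_block)[:32])
    prev_block = bytes.fromhex(b)
io.success(unpad(b''.join(flag), 32).decode())
```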
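To test the whole chain without the remote service, here is a minimal offline sketch. It re-implements the server's `RSA_CBC.encrypt` with the standard library only: the Miller-Rabin generator `random_prime` stands in for `getPrime`, the padding is written by hand in the PKCS#7 style that `pad`/`unpad` use by default, and `server_encrypt` and `attack` are hypothetical names for this illustration. The decryption loop follows the one above, but uses the server's `zip`-based `xor`, so no `[:32]` slice is needed.

```python
import math
import random

E = 0x10001   # public exponent used by the server
BLOCK = 32    # RSA_CBC.BLOCK_LENGTH


def is_probable_prime(n: int, rounds: int = 32) -> bool:
    """Miller-Rabin probabilistic primality test."""
    if n < 2:
        return False
    for sp in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37):
        if n % sp == 0:
            return n == sp
    d, r = n - 1, 0
    while d % 2 == 0:
        d, r = d // 2, r + 1
    for _ in range(rounds):
        x = pow(random.randrange(2, n - 1), d, n)
        if x in (1, n - 1):
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True


def random_prime(bits: int) -> int:
    # Odd candidate with the top bit set, so it has exactly `bits` bits like getPrime(bits)
    while True:
        cand = random.getrandbits(bits) | (1 << (bits - 1)) | 1
        if is_probable_prime(cand):
            return cand


def xor(a: bytes, b: bytes) -> bytes:
    # Server's helper: zip() stops at the shorter input (32 bytes here)
    return bytes(x ^ y for x, y in zip(a, b))


def server_encrypt(m: bytes, p: int, n: int) -> list:
    """Re-implementation of RSA_CBC.encrypt, with PKCS#7 padding and IV = p."""
    k = BLOCK - len(m) % BLOCK
    data = m + bytes([k]) * k
    enc = [p.to_bytes(BLOCK, "big")]
    for i in range(0, len(data), BLOCK):
        x = int.from_bytes(xor(enc[-1], data[i:i + BLOCK]), "big")
        c = pow(x, E, n)
        enc.append(c.to_bytes((c.bit_length() + 7) // 8, "big"))  # like l2b
    return [blk.hex() for blk in enc[1:]]


def attack(flag: bytes) -> bytes:
    while True:  # fresh key pair; retry in the rare case e is not invertible
        p, q = random_prime(256), random_prime(256)
        if p != q and math.gcd(E, (p - 1) * (q - 1)) == 1:
            break
    n = p * q
    # Query 1: first block is 0x00..01, so the server encrypts p XOR 1 = p - 1
    c = int(server_encrypt(b"\x00" * 31 + b"\x01", p, n)[0], 16)
    rp = math.gcd(c + 1, n)
    assert rp == p  # the secret prime has been recovered
    d = pow(E, -1, (rp - 1) * (n // rp - 1))
    # Query 2: decrypt the encrypted flag block by block, undoing the CBC XOR
    out, prev = [], rp.to_bytes(BLOCK, "big")
    for h in server_encrypt(flag, p, n):
        m = pow(int(h, 16), d, n)
        out.append(xor(m.to_bytes(BLOCK, "big"), prev))
        prev = bytes.fromhex(h)
    plain = b"".join(out)
    return plain[:-plain[-1]]  # strip PKCS#7 padding


print(attack(b"HackOn{AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA}"))
```

Each call to `attack` draws a fresh key pair, so running it a few times is a cheap way to convince yourself that the gcd step does not depend on the particular primes.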
Flag
At this point, we can capture the flag:
```console
$ python3 solve.py 0.cloud.chals.io 18923
[+] Opening connection to 0.cloud.chals.io on port 18923: Done
[*] p = 90741731438160432997006881358404031084898961947119721095507310157558171448997
[+] HackOn{Kn0w_i_und3rst4nd_why_n0b0dy_d1d_th1s_b3f0r3_gcd_1s_p0w3rful!!}
[*] Closed connection to 0.cloud.chals.io port 18923
```
The full script can be found here: solve.py.