Oracle Leaks
We are given the Python source code of the server:
```python
import os
import math
from Crypto.Util.number import *
from cryptography.hazmat.primitives.asymmetric import rsa


def get_length(pt):
    res = 0
    if (len(bin(pt)) - 2) % 8 != 0:
        res += 1
    res += (len(bin(pt)) - 2) // 8
    return res


def ceil(a, b):
    return -(-a // b)


class RSA:
    def __init__(self, size):
        self.e = 0x10001
        self.size = size
        priv = rsa.generate_private_key(
            public_exponent=self.e,
            key_size=size
        )
        pub = priv.public_key()
        self.n = pub.public_numbers().n
        self.d = priv.private_numbers().d
        self.n_size = ceil(self.size, 8)
        self.B = 2**((self.n_size-1)*8)

    def pad(self, pt):
        res = 0x02 << 8 * (self.n_size - 2)
        random_pad = os.urandom(self.n_size - 3 - get_length(pt))
        for idx, val in enumerate(random_pad):
            if val == 0:
                val = 1
            res += val << (len(random_pad) - idx + get_length(pt)) * 8
        res += pt
        return res

    def encrypt(self,pt):
        pt = bytes_to_long(pt)
        padded_pt = self.pad(pt)
        ct = pow(padded_pt, self.e, self.n)
        return long_to_bytes(ct).hex()

    def decrypt(self,ct):
        ct = bytes_to_long(ct)
        pt = pow(ct, self.d, self.n)
        return pt


def main():
    FLAG = b'HTB{dummyflag}'
    size = 1024
    tmp = RSA(size)
    flag = tmp.encrypt(FLAG)
    while True:
        try:
            print('Please choose:\n'+\
                '1. Get public key.\n'+\
                '2. Get encrypted flag.\n'+\
                '3. Get length.\n'+\
                '> ')
            opt = input()
            if opt == '1':
                pub_key = (hex(tmp.n)[2:], hex(tmp.e)[2:])
                print('(n,e): ' + str(pub_key) + '\n')
            elif opt == '2':
                print('Encrypted text: ' + flag + '\n')
            elif opt == '3':
                print('Provide a ciphertext:\n'+\
                    '> ')
                ct = input()
                ct = bytes.fromhex(ct)
                pt = tmp.decrypt(ct)
                length = get_length(pt)
                print('Length: ' + str(length) + '\n')
            else:
                print('Wrong option!\n')
                exit(1)
        except Exception as e:
            print(e)
            print('Invalid Input. Exit!')
            exit(1)


if __name__ == "__main__":
    main()
```
Source code analysis
The server encrypts the flag with RSA after applying random padding:
```python
def pad(self, pt):
    res = 0x02 << 8 * (self.n_size - 2)
    random_pad = os.urandom(self.n_size - 3 - get_length(pt))
    for idx, val in enumerate(random_pad):
        if val == 0:
            val = 1
        res += val << (len(random_pad) - idx + get_length(pt)) * 8
    res += pt
    return res

def encrypt(self,pt):
    pt = bytes_to_long(pt)
    padded_pt = self.pad(pt)
    ct = pow(padded_pt, self.e, self.n)
    return long_to_bytes(ct).hex()
```
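To see what `pad` actually produces, here is a small standalone sketch that copies the server's `get_length` and `pad` logic (taking `n_size` as a parameter instead of reading it from the `RSA` object, which is my change). It suggests that the result follows the PKCS #1 v1.5 encryption layout `00 02 || nonzero random bytes || 00 || message`, so for a 128-byte modulus the padded plaintext always starts with a zero byte and is smaller than $256^{127}$:

```python
import os


def get_length(pt):
    # Copied from the challenge: byte length of pt
    res = 0
    if (len(bin(pt)) - 2) % 8 != 0:
        res += 1
    res += (len(bin(pt)) - 2) // 8
    return res


def pad(pt, n_size=128):
    # Same logic as RSA.pad, with n_size passed in explicitly
    res = 0x02 << 8 * (n_size - 2)
    random_pad = os.urandom(n_size - 3 - get_length(pt))
    for idx, val in enumerate(random_pad):
        if val == 0:
            val = 1  # zero bytes are replaced, so the random part never contains 0x00
        res += val << (len(random_pad) - idx + get_length(pt)) * 8
    res += pt
    return res


flag = b"HTB{dummyflag}"
padded = pad(int.from_bytes(flag, "big")).to_bytes(128, "big")

print(padded[:2].hex())               # 0002
print(0 in padded[2:-len(flag) - 1])  # False
print(padded[-len(flag) - 1])         # 0
print(padded[-len(flag):])            # b'HTB{dummyflag}'
```

The byte right before the message is never written by `pad`, which is why it stays zero and acts as the separator.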
The server offers three options:
```python
print('Please choose:\n'+\
    '1. Get public key.\n'+\
    '2. Get encrypted flag.\n'+\
    '3. Get length.\n'+\
    '> ')
```
The relevant one is the last option:
```python
elif opt == '3':
    print('Provide a ciphertext:\n'+\
        '> ')
    ct = input()
    ct = bytes.fromhex(ct)
    pt = tmp.decrypt(ct)
    length = get_length(pt)
    print('Length: ' + str(length) + '\n')
```
As can be seen, we can send the server any ciphertext; it decrypts it and returns the length of the decrypted value. This is `decrypt`:
```python
def decrypt(self,ct):
    ct = bytes_to_long(ct)
    pt = pow(ct, self.d, self.n)
    return pt
```
And this is `get_length`:
```python
def get_length(pt):
    res = 0
    if (len(bin(pt)) - 2) % 8 != 0:
        res += 1
    res += (len(bin(pt)) - 2) // 8
    return res
```
It might look a bit weird, but it just returns the length of `pt` in bytes.
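A quick way to check this claim: for any positive integer, `get_length` agrees with the usual "bit length rounded up to whole bytes". The sketch below copies the server's function and compares it with `(pt.bit_length() + 7) // 8` (the name `byte_length` is mine). The only difference is `pt = 0`, where `bin(0)` is `'0b0'` and `get_length` returns 1:

```python
def get_length(pt):
    # Copied from the challenge
    res = 0
    if (len(bin(pt)) - 2) % 8 != 0:
        res += 1
    res += (len(bin(pt)) - 2) // 8
    return res


def byte_length(pt):
    # Ceiling of the bit length divided by 8
    return (pt.bit_length() + 7) // 8


# Exhaustive check on small values plus a few large ones around 2^1016
for pt in list(range(1, 1 << 16)) + [2**1016 - 1, 2**1016, 2**1024 - 1]:
    assert get_length(pt) == byte_length(pt)

print(get_length(0xff), get_length(0x100), get_length(0x10001))  # 1 2 3
print(get_length(0), byte_length(0))                              # 1 0
```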
Solution
So the challenge seems clear: we must somehow use `get_length` (through option 3) to decrypt the flag that option 2 gives us.
After a bit of research, we came across Manger's attack, a chosen-ciphertext attack on RSA-OAEP.
The attack needs a decryption oracle. In the paper, the oracle is a server that receives a ciphertext $x$ and tells whether $y = x^d \mod{n}$ is less than $B$ or not, where $B$ is a value that depends on $n$. Using only this oracle, the paper gives an algorithm that recovers the plaintext of a target ciphertext.
Oracle
In this challenge, the server returns the length of the decrypted value instead. In the paper, $B = 2^{8(k - 1)}$ with $k = \lceil\log_{256}{n}\rceil$, so $k$ is simply the length of $n$ in bytes.
Here, $n$ is a 1024-bit number, so $k = 128$ and $B = 256^{127}$. Since $y < n < 256^{128}$, the decrypted value is at most 128 bytes long: if its length is exactly 128 bytes, then $y \geqslant B$; if it is shorter, then $y < B$.
With this, we have converted the challenge oracle to the one that appears in the paper, so we only have to implement the algorithm.
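This conversion can be checked locally without the server. The sketch below uses a random 1024-bit odd number as a stand-in for $n$ (an assumption: only its size matters here), computes $k$ and $B$ as in the paper, and verifies that "length is $k$ bytes" is the same predicate as $y \geqslant B$. It also checks the paper's assumption $2B < n$, which always holds for a 1024-bit modulus because $n \geqslant 2^{1023} = 128B$:

```python
import random


def byte_length(y):
    # Same as the server's get_length for y >= 1
    return (y.bit_length() + 7) // 8


n = random.getrandbits(1024) | (1 << 1023) | 1   # stand-in for a 1024-bit modulus
k = byte_length(n)                               # length of n in bytes
B = 2 ** (8 * (k - 1))                           # B = 2^(8(k-1))

print(k, B == 256**127)   # 128 True
assert 2 * B < n

# Edge cases around B plus random values below n
samples = [1, B - 1, B, n - 1] + [random.randrange(1, n) for _ in range(10_000)]
for y in samples:
    # "the decrypted value has k bytes" <=> "y >= B"
    assert (byte_length(y) == k) == (y >= B)
```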
Implementation
There is an attack implementation in Go, which I used as a base for the solution script. Although the attack can also be implemented in Python (like in crypto-attacks), I wanted to try something different and use Go.
I defined some helper functions to work with processes and connections (as if we were using pwntools in Python). In fact, I liked the idea and ended up writing a gopwntools module in Go.
These are the helper functions:
```go
func getProcess() *pwn.Conn {
	if len(os.Args) == 1 {
		return pwn.Process("python3", "chall.py")
	}

	hostPort := strings.Split(os.Args[1], ":")
	return pwn.Remote(hostPort[0], hostPort[1])
}

func divCeil(a, b *big.Int) *big.Int {
	quo, rem := new(big.Int).QuoRem(a, b, new(big.Int))
	if rem.Cmp(zero) > 0 {
		quo.Add(quo, one)
	}
	return quo
}
```
Just to clarify, `divCeil` computes $\lceil a / b \rceil$ using only the integer quotient and remainder, without any floating-point arithmetic. For example:
- `divCeil(15, 4) = 4`, because $\lfloor 15 / 4 \rfloor = 3$ and $15 \mod{4} = 3 \ne 0$, so the quotient is increased to the next integer
- `divCeil(16, 4) = 4`, because $16 / 4 = 4$ and $16 \mod{4} = 0$
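The same helper is easy to express in Python with `divmod` (the name `div_ceil` is mine). It is equivalent to the `-(-a // b)` trick used by the challenge's own `ceil` function, and both stay in exact integer arithmetic, whereas `math.ceil(a / b)` would go through a float and lose precision (or raise `OverflowError`) on 1024-bit numbers:

```python
def div_ceil(a, b):
    # Mirrors divCeil: quotient, plus one if the remainder is nonzero
    # (assumes a >= 0 and b > 0, which is all the attack needs)
    quo, rem = divmod(a, b)
    if rem > 0:
        quo += 1
    return quo


print(div_ceil(15, 4), div_ceil(16, 4))   # 4 4

# Same result as the challenge's ceil(a, b) = -(-a // b)
assert all(div_ceil(a, b) == -(-a // b) for a in range(200) for b in range(1, 50))
```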
These are the global variables used throughout the algorithm:
```go
var (
	io   *pwn.Conn
	e    = big.NewInt(65537)
	zero = big.NewInt(0)
	one  = big.NewInt(1)
	two  = big.NewInt(2)
)
```
We can create the connection to the process/instance and get the public key and the encrypted flag:
```go
func main() {
	io = getProcess()
	defer io.Close()

	io.SendLineAfter([]byte("> "), []byte{'1'})
	io.RecvUntil([]byte("('"))
	n, _ := new(big.Int).SetString(io.RecvUntilS([]byte("'"), true), 16)

	io.SendLineAfter([]byte("> "), []byte{'2'})
	io.RecvUntil([]byte("Encrypted text: "))
	c, _ := new(big.Int).SetString(strings.TrimSpace(io.RecvLineS()), 16)
```
The key function of the algorithm is `oracle`:
```go
func oracle(x, c, n *big.Int) bool {
	test := new(big.Int).Mod(new(big.Int).Mul(new(big.Int).Exp(x, e, n), c), n)
	io.SendLineAfter([]byte("> "), []byte{'3'})
	io.SendLineAfter([]byte("> "), []byte(pwn.Hex(test.Bytes())))
	return strings.Contains(io.RecvLineContainsS([]byte("Length: ")), "128")
}
```
The function takes a number $x$ and sends the server the value $\mathrm{test} = c \cdot x^e \mod{n}$. As a result, the server computes
$$ y = \mathrm{test}^d = (c \cdot x^e)^d = c^d \cdot x^{ed} = m \cdot x \mod{n} $$
where $m = c^d \mod{n}$ is the padded flag, and returns the length of $y$ in bytes. More precisely, `oracle` returns true if and only if that length is 128 bytes, that is, if $m \cdot x \mod{n} \geqslant B$.
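This is the multiplicative property of textbook RSA. The following sketch checks it with a toy key, which is an assumption for illustration only: $n$ is the product of the Mersenne primes $2^{127} - 1$ and $2^{89} - 1$ (far too small for real use), and $m$ is just some integer below $n$ standing in for the padded flag:

```python
p, q = 2**127 - 1, 2**89 - 1        # toy primes, NOT a secure key
n, e = p * q, 65537
d = pow(e, -1, (p - 1) * (q - 1))   # modular inverse (Python 3.8+)

m = int.from_bytes(b"HTB{toy}", "big")   # stands in for the padded flag
c = pow(m, e, n)                         # what option 2 would give us

x = 12345
test = c * pow(x, e, n) % n   # the value oracle() sends
y = pow(test, d, n)           # what the server decrypts

print(y == m * x % n)   # True
```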
Algorithm
The algorithm is explained in the paper, and it is easy to follow.
Although we already know them, we can compute $k$ and $B$:
```go
	k := (n.BitLen() + 7) / 8 // length of n in bytes, rounded up
	B := new(big.Int).Exp(two, big.NewInt(int64(8*(k-1))), nil)
```
Then we start with step 1:
```go
	// Step 1
	f1 := new(big.Int).Set(one)
	for !oracle(f1.Mul(two, f1), c, n) {
	}
```
The aim is to find a value $f_1$ such that $(c \cdot f_1^e)^d \mod{n} \geqslant B$. In other words, $f_1 \cdot m \mod{n} \geqslant B$, where $m$ is the padded flag. Starting with $f_1 = 2$, if the oracle returns false, we multiply $f_1$ by $2$ and try again.
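Step 1 can be sketched in Python against a simulated oracle. By the identity above, the server's answer only depends on $f \cdot m \mod{n}$, so here the oracle is computed directly from a known $m$ (the toy $n$ and $m$ are assumptions). The final assertion checks the invariant that step 2 relies on: when the loop stops, $f_1 \cdot m \in [B, 2B)$, because the previous value $\frac{f_1}{2} \cdot m$ was below $B$ and $2B < n$ means no modular reduction happened:

```python
import random

n = (2**127 - 1) * (2**89 - 1)    # toy modulus, 216 bits
k = (n.bit_length() + 7) // 8     # 27 bytes
B = 2 ** (8 * (k - 1))

# Toy padded plaintext: 00 02 || nonzero random bytes || 00 || message
msg = b"HTB{toy_flag}"
ps = bytes(random.randint(1, 255) for _ in range(k - 3 - len(msg)))
m = int.from_bytes(b"\x00\x02" + ps + b"\x00" + msg, "big")


def oracle(f):
    # Simulated option 3: True when f * m mod n has k bytes, i.e. is >= B
    return f * m % n >= B


# Step 1: double f1 (starting at 2) until f1 * m mod n >= B
f1 = 2
while not oracle(f1):
    f1 *= 2

print(f1)   # 128
assert B <= f1 * m < 2 * B
```

Because the second byte of $m$ is always `0x02`, $m \in [2 \cdot 256^{25}, 3 \cdot 256^{25})$ here, so the loop always stops at $f_1 = 128$ for this toy size.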
The code for step 2 is this one:
```go
	// Step 2
	f12 := new(big.Int).Div(f1, two)
	nB := new(big.Int).Add(n, B)
	nBB := new(big.Int).Div(nB, B)
	f2 := new(big.Int).Mul(nBB, f12)
	// Test the initial f2 first, then add f1/2 while f2*m mod n >= B
	for oracle(f2, c, n) {
		f2.Add(f2, f12)
	}
```
Go's `big.Int` type is quite verbose to use, but the code is doing the following computation:
$$ f_2 = \left\lfloor\frac{n + B}{B}\right\rfloor \cdot \frac{f_1}{2} $$
We want the oracle to return false, that is, $f_2 \cdot m \mod{n} < B$. While it returns true, we add $\frac{f_1}{2}$ to $f_2$ and try again.
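Continuing the same kind of local simulation (again with a toy $n$ and a known $m$, which are assumptions), step 2 looks like this. The first candidate $\left\lfloor\frac{n + B}{B}\right\rfloor \cdot \frac{f_1}{2}$ is tested before adding anything, as in the paper. The assertion checks the invariant $f_2 \cdot m \in [n, n + B)$: the first candidate satisfies $f_2 \cdot m \in [\frac{n}{2}, n + B)$, and each increment adds $\frac{f_1}{2} \cdot m < B$, so the products cannot jump over that interval:

```python
import random

n = (2**127 - 1) * (2**89 - 1)    # toy modulus, 216 bits
k = (n.bit_length() + 7) // 8
B = 2 ** (8 * (k - 1))

msg = b"HTB{toy_flag}"
ps = bytes(random.randint(1, 255) for _ in range(k - 3 - len(msg)))
m = int.from_bytes(b"\x00\x02" + ps + b"\x00" + msg, "big")


def oracle(f):
    # Simulated option 3: True when f * m mod n >= B
    return f * m % n >= B


# Step 1
f1 = 2
while not oracle(f1):
    f1 *= 2

# Step 2: start at floor((n + B) / B) * f1/2, add f1/2 while the oracle says >= B
f2 = (n + B) // B * (f1 // 2)
while oracle(f2):
    f2 += f1 // 2

print(n <= f2 * m < n + B)   # True
```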
Finally, this is step 3:
```go
	// Step 3
	mmin := divCeil(n, f2)
	mmax := new(big.Int).Div(nB, f2)
	BB := new(big.Int).Mul(two, B)
	diff := new(big.Int).Sub(mmax, mmin)

	for diff.Sub(mmax, mmin).Cmp(zero) > 0 {
		ftmp := new(big.Int).Div(BB, diff)
		ftmpmmin := new(big.Int).Mul(ftmp, mmin)
		i := new(big.Int).Div(ftmpmmin, n)
		iN := new(big.Int).Mul(i, n)
		iNB := new(big.Int).Add(iN, B)
		f3 := divCeil(iN, mmin)

		if oracle(f3, c, n) {
			mmin = divCeil(iNB, f3)
		} else {
			mmax.Div(iNB, f3)
		}
	}
```
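This one is more convoluted, so refer to the paper for the details. In short, it starts from the interval $[m_{\min}, m_{\max}] = \left[\left\lceil\frac{n}{f_2}\right\rceil, \left\lfloor\frac{n + B}{f_2}\right\rfloor\right]$, which contains $m$, and each query on a carefully chosen $f_3$ roughly halves it, until $m_{\min} = m_{\max} = m$.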
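To see all three steps working together, here is a Python sketch of the whole attack run against a local oracle. It is my own port of the algorithm as described in the paper, not the Go script; the toy key (product of the Mersenne primes $2^{127} - 1$ and $2^{89} - 1$), the `pad` helper and the `oracle` function are assumptions for illustration. The oracle performs a real RSA decryption followed by the same length check as option 3:

```python
import random

# Toy RSA key (assumption, far too small for real use)
p, q = 2**127 - 1, 2**89 - 1
n, e = p * q, 65537
d = pow(e, -1, (p - 1) * (q - 1))   # Python 3.8+
k = (n.bit_length() + 7) // 8       # length of n in bytes (27)
B = 2 ** (8 * (k - 1))


def ceil_div(a, b):
    return -(-a // b)


def pad(msg):
    # 00 02 || nonzero random bytes || 00 || msg, like the challenge's padding
    ps = bytes(random.randint(1, 255) for _ in range(k - 3 - len(msg)))
    return int.from_bytes(b"\x00\x02" + ps + b"\x00" + msg, "big")


c = pow(pad(b"HTB{toy_flag}"), e, n)   # the "encrypted flag"
queries = 0


def oracle(f):
    # Local option 3: decrypt c * f^e and report whether the result has k bytes
    global queries
    queries += 1
    y = pow(c * pow(f, e, n) % n, d, n)
    return (y.bit_length() + 7) // 8 == k


def manger():
    # Step 1: find f1 with f1 * m in [B, 2B)
    f1 = 2
    while not oracle(f1):
        f1 *= 2
    # Step 2: find f2 with f2 * m in [n, n + B)
    f2 = (n + B) // B * (f1 // 2)
    while oracle(f2):
        f2 += f1 // 2
    # Step 3: shrink [mmin, mmax], which always contains m
    mmin, mmax = ceil_div(n, f2), (n + B) // f2
    while mmin < mmax:
        ftmp = 2 * B // (mmax - mmin)
        i = ftmp * mmin // n
        f3 = ceil_div(i * n, mmin)
        if oracle(f3):
            mmin = ceil_div(i * n + B, f3)
        else:
            mmax = (i * n + B) // f3
    return mmin


m = manger()
print(m.to_bytes(k, "big").split(b"\x00")[-1])   # b'HTB{toy_flag}'
print(queries)                                   # number of oracle calls used
```

With a 1024-bit modulus the same loop needs roughly one query per bit of $B$ in step 3, which is why the Go script sends on the order of a thousand requests to the server.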
After step 3, `mmin` holds the padded plaintext $m$, so we can remove the padding and extract the flag:
```go
	splitted := strings.Split(string(mmin.Bytes()), "\x00")
	flag := splitted[len(splitted)-1]
	pwn.Success("Flag: " + flag)
}
```
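Keeping whatever follows the last `0x00` byte works for this flag, but it would truncate a message that itself contains a zero byte. A stricter alternative, sketched in Python with a hypothetical `unpad` helper, checks the `00 02` prefix and splits at the first zero byte after the random padding (a simplified version of PKCS #1 v1.5 unpadding that skips the minimum-padding-length check). Note that `mmin.Bytes()` in Go drops the leading zero byte, while `to_bytes(k, "big")` keeps it:

```python
def unpad(m, k):
    # Hypothetical helper: parse 00 02 || nonzero padding || 00 || message
    em = m.to_bytes(k, "big")
    if em[:2] != b"\x00\x02":
        raise ValueError("plaintext does not start with 00 02")
    sep = em.index(b"\x00", 2)   # first zero byte after the random padding
    return em[sep + 1:]


k = 128
for msg in [b"HTB{dummyflag}", b"a\x00b"]:
    em = b"\x00\x02" + b"\x07" * (k - 3 - len(msg)) + b"\x00" + msg
    m = int.from_bytes(em, "big")
    last_part = m.to_bytes(k, "big").split(b"\x00")[-1]   # what the Go code keeps
    print(unpad(m, k), last_part)
# b'HTB{dummyflag}' b'HTB{dummyflag}'
# b'a\x00b' b'b'
```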
Flag
If we run the script, we will get the flag (in our local environment and in the remote instance):
```
$ go run solve.go
[+] Starting local process 'python3': pid 78786
[+] Flag: HTB{dummyflag}
[*] Stopped process 'python3' (pid 78786)

$ go run solve.go 94.237.54.176:46699
[+] Opening connection to 94.237.54.176 on port 46699: Done
[+] Flag: HTB{m4ng3r5_4tt4ck_15_c001_4nd_und3rv4lu3d_341m3f}
[*] Closed connection to 94.237.54.176 port 46699
```
The full script can be found here: solve.go.