Oracle Leaks
7 minutos de lectura
Se nos proporciona el código fuente del servidor en Python:
import os
import math
from Crypto.Util.number import *
from cryptography.hazmat.primitives.asymmetric import rsa
def get_length(pt):
res = 0
if (len(bin(pt)) - 2) % 8 != 0:
res += 1
res += (len(bin(pt)) - 2) // 8
return res
def ceil(a, b):
return -(-a // b)
class RSA:
def __init__(self, size):
self.e = 0x10001
self.size = size
priv = rsa.generate_private_key(
public_exponent=self.e,
key_size=size
)
pub = priv.public_key()
self.n = pub.public_numbers().n
self.d = priv.private_numbers().d
self.n_size = ceil(self.size, 8)
self.B = 2**((self.n_size-1)*8)
def pad(self, pt):
res = 0x02 << 8 * (self.n_size - 2)
random_pad = os.urandom(self.n_size - 3 - get_length(pt))
for idx, val in enumerate(random_pad):
if val == 0:
val = 1
res += val << (len(random_pad) - idx + get_length(pt)) * 8
res += pt
return res
def encrypt(self,pt):
pt = bytes_to_long(pt)
padded_pt = self.pad(pt)
ct = pow(padded_pt, self.e, self.n)
return long_to_bytes(ct).hex()
def decrypt(self,ct):
ct = bytes_to_long(ct)
pt = pow(ct, self.d, self.n)
return pt
def main():
FLAG = b'HTB{dummyflag}'
size = 1024
tmp = RSA(size)
flag = tmp.encrypt(FLAG)
while True:
try:
print('Please choose:\n'+\
'1. Get public key.\n'+\
'2. Get encrypted flag.\n'+\
'3. Get length.\n'+\
'> ')
opt = input()
if opt == '1':
pub_key = (hex(tmp.n)[2:], hex(tmp.e)[2:])
print('(n,e): ' + str(pub_key) + '\n')
elif opt == '2':
print('Encrypted text: ' + flag + '\n')
elif opt == '3':
print('Provide a ciphertext:\n'+\
'> ')
ct = input()
ct = bytes.fromhex(ct)
pt = tmp.decrypt(ct)
length = get_length(pt)
print('Length: ' + str(length) + '\n')
else:
print('Wrong option!\n')
exit(1)
except Exception as e:
print(e)
print('Invalid Input. Exit!')
exit(1)
if __name__ == "__main__":
main()
Análisis del código fuente
El servidor usa RSA para cifrar la flag (con un relleno aleatorio):
def pad(self, pt):
res = 0x02 << 8 * (self.n_size - 2)
random_pad = os.urandom(self.n_size - 3 - get_length(pt))
for idx, val in enumerate(random_pad):
if val == 0:
val = 1
res += val << (len(random_pad) - idx + get_length(pt)) * 8
res += pt
return res
def encrypt(self,pt):
pt = bytes_to_long(pt)
padded_pt = self.pad(pt)
ct = pow(padded_pt, self.e, self.n)
return long_to_bytes(ct).hex()
El servidor ofrece tres opciones:
print('Please choose:\n'+\
'1. Get public key.\n'+\
'2. Get encrypted flag.\n'+\
'3. Get length.\n'+\
'> ')
La función relevante es la última:
elif opt == '3':
print('Provide a ciphertext:\n'+\
'> ')
ct = input()
ct = bytes.fromhex(ct)
pt = tmp.decrypt(ct)
length = get_length(pt)
print('Length: ' + str(length) + '\n')
Como se puede ver, podemos pasarle al servidor cualquier texto cifrado, lo descifra y nos devuelve la longitud de la cadena descifrada. Esta es decrypt
:
def decrypt(self,ct):
ct = bytes_to_long(ct)
pt = pow(ct, self.d, self.n)
return pt
Y esta es get_length
:
def get_length(pt):
res = 0
if (len(bin(pt)) - 2) % 8 != 0:
res += 1
res += (len(bin(pt)) - 2) // 8
return res
Puede parecer un poco extraño, pero simplemente devuelve la longitud de pt
en bytes.
Solución
Entonces, el reto parece claro: debemos usar get_length
para poder descifrar de alguna manera la flag (que está cifrada con la opción 2
).
Después de un poco de investigación, encontramos el ataque de Manger, que es un ataque de texto cifrado escogido en RSA-OAEP.
El ataque necesita de un oráculo de descifrado. En el paper, este oráculo es un servidor que recibe un texto cifrado $x$ y dice si $y = x^d \mod{n}$ es menor que $B$ o no, donde $B$ es un valor que depende de $n$.Después de eso, el paper define un algoritmo para recuperar un texto claro.
Oráculo
Esta vez, el servidor devuelve la longitud de la cadena descifrada. Nótese que $B = 2^{8(k - 1)}$ y $k = \lceil\log_{256}{n}\rceil$, como se muestra en el paper. Por lo tanto, $k$ es simplemente la longitud de $n$ en bytes.
Esta vez, $n$ es un número de 1024 bits, por lo que $k = 128$, y $B = 256^{127}$. Como resultado, si la longitud descifrada es exactamente 128 bytes, significa que $y \geqslant B$; y si la longitud es menor que 128 bytes, entonces $y < B$.
Con esto, hemos convertido el oráculo del reto al que aparece en el paper, por lo que solo tenemos que implementar el algoritmo.
Implementación
Hay una implementación del ataque en Go, que utilicé como base para construir el script de solución. Aunque el ataque también se puede realizar en Python (como en crypto-attacks), quería cambiar un poco y usar Go.
Definí las algunas funciones auxiliares para trabajar con procesos y conexiones (como si estuviéramos usando pwntools
en Python). De hecho, me gustó tanto la idea que acabé haciendo un módulo gopwntools
en Go.
Estas son las funciones auxiliares:
func getProcess() *pwn.Conn {
if len(os.Args) == 1 {
return pwn.Process("python3", "chall.py")
}
hostPort := strings.Split(os.Args[1], ":")
return pwn.Remote(hostPort[0], hostPort[1])
}
func divCeil(a, b *big.Int) *big.Int {
quo, rem := new(big.Int).QuoRem(a, b, new(big.Int))
if rem.Cmp(zero) > 0 {
quo.Add(quo, one)
}
return quo
}
Solo para aclarar, divCeil
es como una implementación de la función techo (ceiling) pero utilizando aritmética modular. Por ejemplo:
divCeil(15, 4) = 4
porque $15 / 4 = 3$ y $15 \mod{4} = 3 \ne 0$, entonces el cociente se incrementa al siguiente número enterodivCeil(16, 4) = 4
porque $16 / 4 = 4$ y $16 \mod{4} = 0$
Y estas son variables globales que se usan muchas veces en el algoritmo:
var (
io *pwn.Conn
e = big.NewInt(65537)
zero = big.NewInt(0)
one = big.NewInt(1)
two = big.NewInt(2)
)
Podemos crear la conexión al proceso/instancia y tomar la clave pública y el texto cifrado:
func main() {
io = getProcess()
defer io.Close()
io.SendLineAfter([]byte("> "), []byte{'1'})
io.RecvUntil([]byte("('"))
n, _ := new(big.Int).SetString(io.RecvUntilS([]byte("'"), true), 16)
io.SendLineAfter([]byte("> "), []byte{'2'})
io.RecvUntil([]byte("Encrypted text: "))
c, _ := new(big.Int).SetString(strings.TrimSpace(io.RecvLineS()), 16)
La función clave del algoritmo es oracle
:
func oracle(x, c, n *big.Int) bool {
test := new(big.Int).Mod(new(big.Int).Mul(new(big.Int).Exp(x, e, n), c), n)
io.SendLineAfter([]byte("> "), []byte{'3'})
io.SendLineAfter([]byte("> "), []byte(pwn.Hex(test.Bytes())))
return strings.Contains(io.RecvLineContainsS([]byte("Length: ")), "128")
}
La función toma un número $x$ y enviará al servidor el valor de $\mathrm{test} = c \cdot x^e \mod{n}$. Como resultado, el servidor calcula
$$ y = \mathrm{test}^d = (c \cdot x^e)^d = c^d x = \mathrm{flag} \cdot x \mod{n} $$
y devuelve la longitud de $y$ en bytes. La función en verdad dice si la longitud es igual a 128 bytes o no.
Algoritmo
El algoritmo se explica en el paper, y es fácil de seguir.
Aunque ya los conocemos, podemos calcular $k$ y $B$:
k := n.BitLen() / 8
B := new(big.Int).Exp(two, big.NewInt(int64(8*(k-1))), nil)
Luego comenzamos con el paso 1:
// Step 1
f1 := new(big.Int).Set(one)
for !oracle(f1.Mul(two, f1), c, n) {
}
El objetivo es encontrar un valor $f_1$ tal que $(f_1 \cdot c)^d \mod{n} \geqslant B$. En otras palabras, $f_1^d \cdot \mathrm{flag} \mod{n} > B$. Si el oráculo devuelve falso, entonces multiplicamos $f_1$ por $2$ y lo intentamos de nuevo.
El código para el paso 2 es este:
// Step 2
f12 := new(big.Int).Div(f1, two)
nB := new(big.Int).Add(n, B)
nBB := new(big.Int).Div(nB, B)
f2 := new(big.Int).Mul(nBB, f12)
for oracle(f2.Add(f2, f12), c, n) {
}
Go y el uso del tipo big.Int
es muy raro y redundante, pero el código está haciendo el siguiente cálculo:
$$ f_2 = \left\lfloor\frac{n + B}{B}\right\rfloor \cdot \frac{f_1}{2} $$
Queremos que el oráculo devuelva falso (menor que $B$). Si el oráculo devuelve verdadero, entonces sumamos $\frac{f_1}{2}$ a $f_2$ y probamos de nuevo.
Finalmente, este es el paso 3:
// Step 3
mmin := divCeil(n, f2)
mmax := new(big.Int).Div(nB, f2)
BB := new(big.Int).Mul(two, B)
diff := new(big.Int).Sub(mmax, mmin)
for diff.Sub(mmax, mmin).Cmp(zero) > 0 {
ftmp := new(big.Int).Div(BB, diff)
ftmpmmin := new(big.Int).Mul(ftmp, mmin)
i := new(big.Int).Div(ftmpmmin, n)
iN := new(big.Int).Mul(i, n)
iNB := new(big.Int).Add(iN, B)
f3 := divCeil(iN, mmin)
if oracle(f3, c, n) {
mmin = divCeil(iNB, f3)
} else {
mmax.Div(iNB, f3)
}
}
Este es más complicado, así que véase el paper para más información.
En este punto, tenemos el texto claro, por lo que podamos eliminar el relleno y extraer la flag:
splitted := strings.Split(string(mmin.Bytes()), "\x00")
flag := splitted[len(splitted)-1]
pwn.Success("Flag: " + flag)
}
Flag
Si ejecutamos el script, obtendremos la flag (en nuestro entorno local y en la instancia remota):
$ go run solve.go
[+] Starting local process 'python3': pid 78786
[+] Flag: HTB{dummyflag}
[*] Stopped process 'python3' (pid 78786)
$ go run solve.go 94.237.54.176:46699
[+] Opening connection to 94.237.54.176 on port 46699: Done
[+] Flag: HTB{m4ng3r5_4tt4ck_15_c001_4nd_und3rv4lu3d_341m3f}
[*] Closed connection to 94.237.54.176 port 46699
El script completo se puede encontrar aquí: solve.go
.