Composition
8 minutes to read
We are given the source code of the server in Python:
from Crypto.Util.number import isPrime, getPrime, GCD, long_to_bytes, bytes_to_long
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from secret import flag
from ecc import EllipticCurve
from hashlib import md5
import os
import random
print("Welcome to the ECRSA test center. Your encrypted data will be sent soon.")
print("Please check the logs for the parameters.")
# Euler's criterion: for an odd prime p, x^((p-1)/2) mod p is 1 when x is a nonzero
# quadratic residue, p-1 when it is a non-residue, and 0 when p divides x.
legendre = lambda x,p: pow(x,(p-1)//2,p)
def next_prime(num):
    # Returns the smallest odd number greater than num that isPrime accepts.
    # NOTE: only odd candidates are tried, so next_prime(0) and next_prime(1) return 3
    # (skipping 2); keygen only calls this on large numbers, so it does not matter here.
    if num % 2 == 0:
        num += 1  # even -> first odd number above it
    else:
        num += 2  # odd -> next odd number
    while not isPrime(num):
        num += 2
    return num
def getrandpoint(ec,p,q):
    # Draws random x in [1, p*q] until x^3 + a*x + b is a quadratic residue modulo
    # both p and q, so that ec.lift_x can compute a y-coordinate for it.
    # `expr` is the module-level curve polynomial defined further down in the script.
    # Uses the `random` module (Mersenne Twister), which is not a cryptographic RNG.
    num = random.randint(1,p*q)
    while legendre(expr(num),p) != 1 or legendre(expr(num),q) != 1:
        num = random.randint(1,p*q)
    return ec.lift_x(num,p,q)
# Calculate discriminant(ensures elliptic curve is non-singular)
# Computes -16 * (4a^3 + 27b^2) mod n; y^2 = x^3 + ax + b is singular when this is 0.
calc_discrim = lambda a,b,n: (-16 * (4 * a**3 + 27 * b**2)) % n
def keygen(bits):
    # Returns RSA key in form ((e,n),(p,q))
    # p: a (bits // 2)-bit prime, bumped with next_prime until it is not 1 mod 4,
    # i.e. p = 3 (mod 4), as required by composite_square_root.
    p = getPrime(bits // 2)
    while p % 4 == 1:
        p = next_prime(p)
    # e depends only on p: the next prime after p with its low bits // 4 bits dropped.
    e = next_prime(p >> (bits // 4))
    # q is derived from p: at least the 51st prime after p, then bumped until it is
    # not 1 mod 4. This keeps q very close to p, so p < sqrt(n) < q.
    q = next_prime(p)
    for i in range(50):
        q = next_prime(q)
    while q % 4 == 1:
        q = next_prime(q)
    n = p * q
    # Retry from scratch unless n has exactly `bits` bits.
    if n.bit_length() != bits:
        return keygen(bits)
    return (e,n),(p,q)
print("Generating your key...")
# keygen returns ((e, n), (p, q)); of these, only n is printed below.
key = keygen(512)
e,n = key[0]
p,q = key[1]
print("Creating ECC params")
# Gotta make sure the params are valid
a,b = random.getrandbits(128),random.getrandbits(128)
discrim = calc_discrim(a,b,n)
# Curve polynomial. It reads the globals a and b at call time, so it also reflects
# any values re-drawn by the loop below (getrandpoint relies on it).
expr = lambda x: x**3 + a*x + b
# Re-draw a and b while the discriminant is 0 mod n (singular curve).
while not discrim:
    a,b = random.getrandbits(128),random.getrandbits(128)
    discrim = calc_discrim(a,b,n)
ec = EllipticCurve(a,b,n)
g = getrandpoint(ec,p,q)
# Published "ECRSA ciphertext": A = e*G on the curve modulo n.
A = ec.multiply(g,e)
# Use key that has been shared with ECRSA
# (`key` is reused: the AES key is the MD5 digest of G's x-coordinate in decimal.)
key = md5(str(g.x).encode()).digest()
iv = os.urandom(16)
cipher = AES.new(key,AES.MODE_CBC,iv)
data = cipher.encrypt(pad(flag,16))
print(f"Encrypted flag: {data.hex()}")
print(f"IV: {iv.hex()}")
print(f"N: {n}")
print(f"ECRSA Ciphertext: {A}")
print("Would you like to test the ECRSA curve?")
# Any answer other than exactly 'n' continues and prints a second point on the curve.
if input("[y/n]> ") == 'n':
    exit()
print("Generating random point...")
print(getrandpoint(ec,p,q))
print("Thanks for testing!")
It encrypts the flag with AES, using a key derived from a point on an elliptic curve that is implemented in a custom library (ecc.py):
from collections import namedtuple
from functools import reduce
from operator import mul
from Crypto.Util.number import inverse
import random
# Affine point as a plain (x, y) tuple; its repr is "Point(x=..., y=...)".
Point = namedtuple("Point","x y")
def moddiv(x,y,p):
    # Returns x / y mod p, i.e. x * y^(-1) mod p.
    # NOTE(review): Crypto.Util.number.inverse is expected to raise when y has no
    # inverse mod p (possible here, since the curve modulus is composite) — confirm.
    return (x * inverse(y,p)) % p
def crt(*args):
    # Takes a bunch of lists in form [value,modulus]
    # Returns the x in [0, N), N = product of the moduli, with x = value_i (mod modulus_i)
    # for every pair. The moduli must be pairwise coprime, otherwise the inverse
    # below does not exist.
    values = [row[0] for row in args]
    ns = [row[1] for row in args]
    N = reduce(mul,ns)
    _sum = 0
    for i in range(len(args)):
        yi = N // ns[i]          # product of all the other moduli
        zi = inverse(yi,ns[i])   # yi * zi = 1 (mod ns[i])
        _sum += values[i]*yi*zi
    return _sum % N
def composite_square_root(num,p,q):
    # Only works if num is a quadratic residue mod p and q AND p and q are 3 mod 4
    # For a prime r = 3 (mod 4) and a residue num, num^((r+1)/4) mod r is a square root.
    # The two prime-modulus roots are combined with crt() into a root mod n = p*q.
    # The choice among the possible roots mod n is deterministic (same input, same root).
    n = p * q
    root1 = pow(num,(p+1)//4,p)
    root2 = pow(num,(q+1)//4,q)
    ans = crt([root1,p],[root2,q])
    # Sanity check (skipped when Python runs with -O).
    assert pow(ans,2,n) == (num % n)
    return ans
class EllipticCurve:
    # Short Weierstrass curve y^2 = x^3 + a*x + b with all arithmetic modulo self.p.
    # In this challenge self.p is the composite RSA modulus n, not a prime.
    # Point(0, 0) is used as a sentinel for the point at infinity.
    INF = Point(0,0)

    def __init__(self, a, b, p):
        self.a = a
        self.b = b
        self.p = p

    def add(self,P,Q):
        # Affine point addition / doubling modulo self.p.
        if P == self.INF:
            return Q
        elif Q == self.INF:
            return P
        # Q is the negation of P (this also covers doubling a point with y = 0).
        if P.x == Q.x and P.y == (-Q.y % self.p):
            return self.INF
        if P != Q:
            # Chord slope: (y2 - y1) / (x2 - x1)
            Lambda = moddiv(Q.y - P.y, Q.x - P.x, self.p)
        else:
            # Tangent slope: (3x^2 + a) / (2y)
            Lambda = moddiv(3 * P.x**2 + self.a,2 * P.y , self.p)
        Rx = (Lambda**2 - P.x - Q.x) % self.p
        Ry = (Lambda * (P.x - Rx) - P.y) % self.p
        return Point(Rx,Ry)

    def multiply(self,P,n):
        # Scalar multiplication n*P using right-to-left double-and-add.
        # NOTE: this reduces the scalar modulo the curve modulus, which is not the group
        # order; it is harmless here because the scalar used (e) is smaller than n.
        n %= self.p
        # NOTE(review): after `n %= self.p` (self.p > 0), n is never negative, so this
        # branch is unreachable. It also uses a bare `p`, which is not defined in this
        # module and would raise NameError if reached; `self.p` was presumably intended.
        if n != abs(n):
            ans = self.multiply(P,abs(n))
            return Point(ans.x, -ans.y % p)
        R = self.INF
        while n > 0:
            if n % 2 == 1:
                R = self.add(R,P)
            P = self.add(P,P)
            n = n // 2
        return R

    def lift_x(self,x,p,q):
        # Builds the point with this x-coordinate. y is the square root of
        # x^3 + a*x + b chosen by composite_square_root (which requires p and q to be
        # 3 mod 4 and the value to be a residue modulo both), so the result is deterministic.
        expr = x**3 + self.a*x + self.b
        y = composite_square_root(expr,p,q)
        return Point(x,y)
Source code analysis
The server creates an RSA public key $(e, n)$, although we are not given $e$:
# Excerpt of the server: the RSA-style key is generated, but only n is revealed.
print("Generating your key...")
key = keygen(512)
e,n = key[0]
p,q = key[1]
# ...
print(f"N: {n}")
RSA
The public modulus $n = p \cdot q$, where $p$ and $q$ are two big prime numbers:
def keygen(bits):
    # Returns RSA key in form ((e,n),(p,q))
    p = getPrime(bits // 2)
    while p % 4 == 1:
        p = next_prime(p)
    # e depends only on p
    e = next_prime(p >> (bits // 4))
    # q: at least the 51st prime after p, then bumped until it is not 1 mod 4
    q = next_prime(p)
    for i in range(50):
        q = next_prime(q)
    while q % 4 == 1:
        q = next_prime(q)
    n = p * q
    if n.bit_length() != bits:
        return keygen(bits)
    return (e,n),(p,q)
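As can be seen, $p$ is a 256-bit prime number and $q$ is very close to $p$, because it is computed from $p$ with `next_prime`. More precisely, $q$ is at least the 51st prime after $p$: there is one `next_prime` call before the loop and 50 inside it, plus possibly a few more calls until $q \equiv 3 \pmod{4}$. Even so, the difference is tiny compared to $p$, so we can write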
$$ q = p + k $$
for some small $k \in \mathbb{Z}$ (in the run shown below, $k = q - p = 9012$). Then, we have
$$ n = p \cdot q = p^2 + p \cdot k $$
As a result, we have some approximation of $p$ using the square root:
$$ p \approx \sqrt{n} $$
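The result won't be exactly $p$. However, since $p < \sqrt{n} < q$, the integer square root of $n$ lies between the two factors. So we can start there and apply `next_prime` until we reach a prime that divides $n$, which must be $q$; then $p = n / q$. The public exponent $e$ depends only on $p$, so we also get $e$ at this point.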
ECC
After that, the server creates an elliptic curve over $\mathbb{Z}/n\mathbb{Z}$ with random parameters $a$ and $b$, which are not given to us:
print("Creating ECC params")
# Gotta make sure the params are valid
a,b = random.getrandbits(128),random.getrandbits(128)
discrim = calc_discrim(a,b,n)
expr = lambda x: x**3 + a*x + b
# a and b are re-drawn until the curve is non-singular modulo n
while not discrim:
    a,b = random.getrandbits(128),random.getrandbits(128)
    discrim = calc_discrim(a,b,n)
ec = EllipticCurve(a,b,n)
Then, it generates a random point $G$, which is used to derive an AES key to encrypt the flag, and we are given $A = e \cdot G$:
g = getrandpoint(ec,p,q)
A = ec.multiply(g,e)
# Use key that has been shared with ECRSA
# (only G's x-coordinate is hashed, so recovering x(G) is enough to decrypt)
key = md5(str(g.x).encode()).digest()
iv = os.urandom(16)
cipher = AES.new(key,AES.MODE_CBC,iv)
data = cipher.encrypt(pad(flag,16))
print(f"Encrypted flag: {data.hex()}")
print(f"IV: {iv.hex()}")
print(f"N: {n}")
print(f"ECRSA Ciphertext: {A}")
Last but not least, we are given the chance to get another random point $R$ on the curve:
print("Would you like to test the ECRSA curve?")
if input("[y/n]> ") == 'n':
    exit()
# Answering anything but 'n' reveals a second point R on the same curve.
print("Generating random point...")
print(getrandpoint(ec,p,q))
print("Thanks for testing!")
This is important to recover the curve parameters, because both $A$ and $R$ lie on the curve, so
$$ \begin{cases} A_y^2 = A_x^3 + a A_x + b \mod{n} \\ R_y^2 = R_x^3 + a R_x + b \mod{n} \\ \end{cases} $$
Subtracting both equations, we have:
$$ A_y^2 - R_y^2 = A_x^3 - R_x^3 + a (A_x - R_x) \mod{n} $$
Since $A_x - R_x$ is invertible modulo $n$ (except with negligible probability), we can isolate $a$:
$$ a = (A_y^2 - R_y^2 - A_x^3 + R_x^3) \cdot (A_x - R_x)^{-1} \mod{n} $$
And then it is trivial to find $b$:
$$ b = A_y^2 - A_x^3 - a A_x \mod{n} $$
Solution
So, up to this point, we know how to factor $n = p \cdot q$, how to get $e$ and how to recover the curve parameters.
The last thing we need is to find $G$ such that $A = e \cdot G$. The way to do this is to find a value $d$ such that $G = d \cdot A$, that is, $d \equiv e^{-1}$ modulo the order of the group of points. The problem is that the curve is defined over a composite modulus, so we cannot directly compute its order. As a result, we can't easily find the inverse of $e$.
What we can do is apply the Chinese Remainder Theorem in a peculiar way. By the CRT, the curve over $\mathbb{Z}/n\mathbb{Z}$ behaves like the pair of curves obtained by reducing modulo $p$ and modulo $q$. Over a prime field we can compute the group order, and therefore invert a scalar, provided it is coprime to that order. So we invert $e$ modulo the orders of the curves over $\mathbb{F}_p$ and $\mathbb{F}_q$ and compute $G$ on both curves. Then we apply the CRT to the two $\mathrm{x}$ coordinates to get the $\mathrm{x}$ coordinate of $G$ over $\mathbb{Z}/n\mathbb{Z}$, which is all we need to derive the AES key.
Implementation
First of all, we get all the needed information from the challenge instance:
io = get_process()


def read_hex_line(marker):
    # Skip past `marker`, then decode the rest of that line as hex bytes.
    io.recvuntil(marker)
    return bytes.fromhex(io.recvline().strip().decode())


def read_point(marker):
    # Parse a namedtuple repr "Point(x=<int>, y=<int>)"; `marker` must end right
    # before the x value.
    io.recvuntil(marker)
    x = int(io.recvuntil(b',')[:-1].decode())
    io.recvuntil(b'y=')
    y = int(io.recvuntil(b')')[:-1].decode())
    return x, y


flag_enc = read_hex_line(b'Encrypted flag: ')
iv = read_hex_line(b'IV: ')
io.recvuntil(b'N: ')
n = int(io.recvline().decode())
# A = e*G, printed by the server as "ECRSA Ciphertext: Point(x=..., y=...)"
Ax, Ay = read_point(b'ECRSA Ciphertext: Point(x=')
# Ask the server for an extra random point R on the same curve.
io.sendlineafter(b'[y/n]> ', b'y')
Rx, Ry = read_point(b'Point(x=')
Then, we factor $n$ as explained previously:
# p < sqrt(n) < q, so stepping upward through primes from isqrt(n) reaches q
# before any other divisor of n.
q = isqrt(n)
while n % q:
    q = next_prime(q)
p = n // q
io.info(f'p = {p!r}')
io.info(f'q = {q!r}')
At this point, we also have $e$:
# keygen() used e = next_prime(p >> (bits // 4)) and only accepts an n with exactly
# `bits` bits, so the shift amount can be recomputed from n.
shift = n.bit_length() // 4
e = next_prime(p >> shift)
io.info(f'e = {e!r}')
Next, we find the curve parameters by solving the system of equations:
# Both A and R satisfy y^2 = x^3 + a*x + b (mod n). Each lin_* below equals a*x + b
# for its point, so subtracting them eliminates b and leaves a linear equation in a.
lin_A = (pow(Ay, 2, n) - pow(Ax, 3, n)) % n
lin_R = (pow(Ry, 2, n) - pow(Rx, 3, n)) % n
a = (lin_A - lin_R) * pow(Ax - Rx, -1, n) % n
b = (lin_A - a * Ax) % n
io.info(f'a = {a!r}')
io.info(f'b = {b!r}')
After that, we define the curves over $\mathbb{Z}/n\mathbb{Z}$, $\mathbb{F}_p$ and $\mathbb{F}_q$. We compute $G$ on the last two curves and combine the results with the CRT to find the $\mathrm{x}$ coordinate of $G$ on the first one:
# The challenge's curve over Z/nZ (ecc.py) and Sage curves over F_p and F_q.
En = ecc.EllipticCurve(a, b, n)
Ep = EllipticCurve(GF(p), [a, b])
Eq = EllipticCurve(GF(q), [a, b])
Ap = Ep(Ax, Ay)
Aq = Eq(Ax, Ay)
# Over a prime field the group order is computable, so e can be inverted modulo it.
# This needs gcd(e, #E) == 1; otherwise the modular inverse does not exist.
Gp = pow(e, -1, Ep.order()) * Ap
Gq = pow(e, -1, Eq.order()) * Aq
# Only G's x-coordinate feeds the AES key. Combine the two residues with Sage's
# crt(residues, moduli) (not ecc.crt, which takes [value, modulus] pairs).
# The previous line had mismatched brackets (`int(Gp.x()])`), a syntax error.
Gn_x = crt([int(Gp.xy()[0]), int(Gq.xy()[0])], [p, q])
# lift_x deterministically picks the same square root the server used, so
# multiplying the lifted point by e must reproduce A.
assert En.multiply(En.lift_x(Gn_x, p, q), e) == ecc.Point(Ax, Ay)
io.success(f'{Gn_x = }')
At this point, we can derive the AES key and get the flag:
# The server used MD5 of the decimal string of G's x-coordinate as the AES-128 key.
key = md5(str(Gn_x).encode()).digest()
cipher = AES.new(key, AES.MODE_CBC, iv)
padded = cipher.decrypt(flag_enc)
flag = unpad(padded, AES.block_size).decode()
io.success(flag)
Flag
If we run the script, we will get the flag:
$ python3 solve.py 94.237.63.83:33306
[+] Opening connection to 94.237.63.83 on port 33306: Done
[*] p = 106858941313638767991828235567490821200125820455933241643545899194829771009579
[*] q = 106858941313638767991828235567490821200125820455933241643545899194829771018591
[*] e = 314030204622581805680811063725039822717
[*] a = 37753997141981190002907126778194883468
[*] b = 37744197496150834695308441384696040008
[+] Gn_x = 4174411158930651296922645398385284926646467052831144880881853434751972793707369366842219778664803546408964829711287940023692621868869243079343630800011414
[+] HTB{Pr1M3_Pr0x1mIty_@nD_C0mP0s1T3_CuRv3?????=>s0_1nS3cUr3!!!}
[*] Closed connection to 94.237.63.83 port 33306
The full script can be found here: solve.py.