Composition
8 minutes to read
We are given the source code of the server in Python:
from Crypto.Util.number import isPrime, getPrime, GCD, long_to_bytes, bytes_to_long
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from secret import flag
from ecc import EllipticCurve
from hashlib import md5
import os
import random
print("Welcome to the ECRSA test center. Your encrypted data will be sent soon.")
print("Please check the logs for the parameters.")


def legendre(x, p):
    """Evaluate Euler's criterion x^((p-1)/2) mod p.

    For an odd prime p the result is 1 when x is a non-zero quadratic
    residue, p - 1 when it is a non-residue, and 0 when p divides x.
    """
    half_order = (p - 1) // 2
    return pow(x, half_order, p)
def next_prime(num):
    """Return the first prime found by walking upward over odd numbers.

    The walk starts at the next odd number above ``num`` (num + 1 when num is
    even, num + 2 when it is odd), so the result is always greater than
    ``num`` and is never 2.
    """
    candidate = num + (1 if num % 2 == 0 else 2)
    while not isPrime(candidate):
        candidate += 2
    return candidate
def getrandpoint(ec, p, q):
    """Return a random point on ``ec`` whose x-coordinate lies in [1, p*q].

    Draws x until the curve's right-hand side ``expr(x)`` is a quadratic
    residue modulo both p and q (Euler's criterion gives exactly 1), then
    lifts x to a point with ``ec.lift_x``.
    """
    while True:
        candidate = random.randint(1, p * q)
        rhs = expr(candidate)
        if legendre(rhs, p) == 1 and legendre(rhs, q) == 1:
            return ec.lift_x(candidate, p, q)
def calc_discrim(a, b, n):
    """Return the discriminant -16(4a^3 + 27b^2) of y^2 = x^3 + ax + b, reduced mod n.

    A result of 0 means the curve is singular modulo n.
    """
    return -16 * (4 * pow(a, 3) + 27 * pow(b, 2)) % n
def keygen(bits):
    """Generate an (intentionally weak) RSA-style key.

    Returns ``((e, n), (p, q))``:
    - p is a random (bits // 2)-bit prime, stepped with ``next_prime`` until
      it is no longer 1 mod 4;
    - e = next_prime(p >> (bits // 4)), so it is fully determined by p;
    - q is reached from p by 51 ``next_prime`` calls plus the steps needed
      to make q no longer 1 mod 4, so q is close to p.
    Fresh primes are drawn until n = p * q has exactly ``bits`` bits.
    """
    # NOTE(review): the source's indentation was lost; the q-adjustment loop
    # is assumed to run once, after the 50 next_prime steps.
    while True:
        p = getPrime(bits // 2)
        while p % 4 == 1:
            p = next_prime(p)
        e = next_prime(p >> (bits // 4))
        q = next_prime(p)
        for _ in range(50):
            q = next_prime(q)
        while q % 4 == 1:
            q = next_prime(q)
        n = p * q
        if n.bit_length() == bits:
            return (e, n), (p, q)
# --- Server entry point: build the weak key and curve, encrypt the flag, leak one extra point ---
print("Generating your key...")
key = keygen(512)
e,n = key[0]
p,q = key[1]
print("Creating ECC params")
# Gotta make sure the params are valid
a,b = random.getrandbits(128),random.getrandbits(128)
discrim = calc_discrim(a,b,n)
# expr reads the module-level a and b when it is called, so it also sees
# values regenerated by the loop below.
expr = lambda x: x**3 + a*x + b
# NOTE(review): only a discriminant exactly 0 mod n is rejected; one that
# shares a factor with n would still be accepted.
while not discrim:
    a,b = random.getrandbits(128),random.getrandbits(128)
    discrim = calc_discrim(a,b,n)
# Curve y^2 = x^3 + a*x + b over Z/nZ (n is composite)
ec = EllipticCurve(a,b,n)
g = getrandpoint(ec,p,q)
# "ECRSA" encryption: scalar-multiply g by the exponent e
A = ec.multiply(g,e)
# Use key that has been shared with ECRSA
# AES-128 key = MD5 digest (16 bytes) of the decimal x-coordinate of g
key = md5(str(g.x).encode()).digest()
iv = os.urandom(16)
cipher = AES.new(key,AES.MODE_CBC,iv)
data = cipher.encrypt(pad(flag,16))
# Leaked values: ciphertext, IV, modulus and A = e*g (not e, a, b, or g)
print(f"Encrypted flag: {data.hex()}")
print(f"IV: {iv.hex()}")
print(f"N: {n}")
print(f"ECRSA Ciphertext: {A}")
print("Would you like to test the ECRSA curve?")
if input("[y/n]> ") == 'n':
    exit()
print("Generating random point...")
# A second point on the same curve; together with A it reveals a and b
print(getrandpoint(ec,p,q))
print("Thanks for testing!")
It encrypts the flag with AES, using a key derived from a random point on an elliptic curve. The curve arithmetic is implemented in a custom library (`ecc.py`):
from collections import namedtuple
from functools import reduce
from operator import mul
from Crypto.Util.number import inverse
import random
# Affine curve point; also used for the point-at-infinity sentinel Point(0, 0).
Point = namedtuple("Point", ["x", "y"])
def moddiv(x, y, p):
    """Return x / y modulo p, i.e. x * y^-1 mod p.

    Raises ValueError when y has no inverse modulo p.
    """
    y_inv = pow(y, -1, p)
    return x * y_inv % p
def crt(*args):
    """Solve a system of congruences with the Chinese Remainder Theorem.

    Each positional argument is a ``[value, modulus]`` pair; the moduli are
    expected to be pairwise coprime. Returns the solution in
    ``[0, product_of_moduli)``.
    """
    moduli = [row[1] for row in args]
    N = reduce(mul, moduli)
    total = 0
    for row, modulus in zip(args, moduli):
        cofactor = N // modulus
        total += row[0] * cofactor * pow(cofactor, -1, modulus)
    return total % N
def composite_square_root(num, p, q):
    """Return a square root of ``num`` modulo n = p * q.

    Uses the p ≡ 3 (mod 4) shortcut r = num^((p+1)/4) mod p (likewise mod q),
    then recombines the two roots with the two-modulus CRT.

    Preconditions: p and q are distinct primes congruent to 3 mod 4, and
    ``num`` is a quadratic residue modulo both.

    Raises:
        ValueError: if the result does not square to ``num`` mod n, i.e. the
            preconditions were violated.
    """
    n = p * q
    root_p = pow(num, (p + 1) // 4, p)
    root_q = pow(num, (q + 1) // 4, q)
    # Two-modulus CRT (Garner form): ans ≡ root_p (mod p), ans ≡ root_q (mod q),
    # with 0 <= ans < n.
    ans = (root_p + p * ((root_q - root_p) * pow(p, -1, q) % q)) % n
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if pow(ans, 2, n) != num % n:
        raise ValueError("num is not a quadratic residue modulo both p and q")
    return ans
class EllipticCurve:
    """Short Weierstrass curve y^2 = x^3 + a*x + b with coordinates mod ``p``.

    ``p`` need not be prime (``lift_x`` takes its two factors). With a
    composite modulus, the slope denominator in ``add`` may be non-invertible,
    in which case ``moddiv`` fails.

    Points are ``Point(x, y)`` tuples. ``INF`` is the identity element.
    """

    # Sentinel for the point at infinity.
    # NOTE(review): collides with the affine point (0, 0), which is on the curve when b == 0.
    INF = Point(0, 0)

    def __init__(self, a, b, p):
        self.a = a
        self.b = b
        self.p = p

    def add(self, P, Q):
        """Return P + Q using the affine chord-and-tangent formulas mod self.p."""
        if P == self.INF:
            return Q
        elif Q == self.INF:
            return P
        # P + (-P) = INF; this also covers doubling a point whose y is 0.
        if P.x == Q.x and P.y == (-Q.y % self.p):
            return self.INF
        if P != Q:
            # Slope of the chord through two distinct points.
            Lambda = moddiv(Q.y - P.y, Q.x - P.x, self.p)
        else:
            # Slope of the tangent, for point doubling.
            Lambda = moddiv(3 * P.x**2 + self.a, 2 * P.y, self.p)
        Rx = (Lambda**2 - P.x - Q.x) % self.p
        Ry = (Lambda * (P.x - Rx) - P.y) % self.p
        return Point(Rx, Ry)

    def multiply(self, P, n):
        """Return the scalar multiple n*P using right-to-left double-and-add.

        A negative ``n`` returns -(|n|*P), i.e. the point with y negated mod
        self.p. The scalar is deliberately NOT reduced modulo ``self.p``: that
        is the coordinate modulus, not the group order, so (n mod p)*P is in
        general a different point from n*P.
        """
        if n < 0:
            # Previously unreachable (after `n %= self.p`) and referenced an
            # undefined global `p`; negate using this curve's modulus.
            ans = self.multiply(P, -n)
            return Point(ans.x, -ans.y % self.p)
        R = self.INF
        while n > 0:
            if n % 2 == 1:
                R = self.add(R, P)
            P = self.add(P, P)
            n //= 2
        return R

    def lift_x(self, x, p, q):
        """Return a curve point with x-coordinate ``x``.

        ``p`` and ``q`` are presumably the prime factors of ``self.p``. They
        are passed to ``composite_square_root`` to take the square root of the
        right-hand side, so that helper's preconditions (p and q ≡ 3 mod 4,
        right-hand side a residue modulo both) must hold.
        """
        expr = x**3 + self.a * x + self.b
        y = composite_square_root(expr, p, q)
        return Point(x, y)
Source code analysis
The server creates an RSA-style key $(e, n)$, but only the modulus $n$ is printed:
print("Generating your key...")
# 512-bit modulus n = p*q; the exponent e is not among the printed values
key = keygen(512)
e,n = key[0]
p,q = key[1]
# ...
print(f"N: {n}")
RSA
The public modulus $n = p \cdot q$ and the exponent $e$ are generated by `keygen`:
def keygen(bits):
    # Returns RSA key in form ((e,n),(p,q))
    # p: random (bits//2)-bit prime, stepped with next_prime while it is 1 mod 4
    p = getPrime(bits // 2)
    while p % 4 == 1:
        p = next_prime(p)
    # e depends only on the top bits of p, so knowing p reveals e
    e = next_prime(p >> (bits // 4))
    # q is only ~51 primes above p, so p and q are very close
    # NOTE(review): indentation reconstructed; the 3-mod-4 loop is assumed to run after the for loop
    q = next_prime(p)
    for i in range(50):
        q = next_prime(q)
    while q % 4 == 1:
        q = next_prime(q)
    n = p * q
    # Retry until the modulus has exactly `bits` bits
    if n.bit_length() != bits:
        return keygen(bits)
    return (e,n),(p,q)
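As can be seen, $q$ is derived from $p$ with `next_prime`: it is at least the 51st prime after $p$, possibly a few more to make $q \equiv 3 \pmod 4$. Since the gaps between consecutive primes of this size are small, $q = p + k$ for some small $k$ (in the run below, $k = 9012$).

As a result, $n = p(p + k) \approx p^2$, so $\lfloor\sqrt{n}\rfloor$ is a good approximation of both factors and lies between them: $p < \lfloor\sqrt{n}\rfloor < q$.

The result won't be exactly $p$ or $q$. However, we can start from $\lfloor\sqrt{n}\rfloor$ and call `next_prime` repeatedly until we find a prime that divides $n$, which is $q$. Then $p = n / q$.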
ECC
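After that, the server creates an elliptic curve $E: y^2 = x^3 + ax + b$ over $\mathbb{Z}/n\mathbb{Z}$, with random 128-bit coefficients $a$ and $b$ that are not printed: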
print("Creating ECC params")
# Gotta make sure the params are valid
# a and b are random 128-bit integers and are never printed
a,b = random.getrandbits(128),random.getrandbits(128)
discrim = calc_discrim(a,b,n)
# expr reads the module-level a and b when called, so it tracks any regeneration below
expr = lambda x: x**3 + a*x + b
# Redraw only if the discriminant is exactly 0 mod n
while not discrim:
    a,b = random.getrandbits(128),random.getrandbits(128)
    discrim = calc_discrim(a,b,n)
# Curve y^2 = x^3 + a*x + b over Z/nZ
ec = EllipticCurve(a,b,n)
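Then, it generates a random point $G$ on the curve and computes $A = e \cdot G$. The MD5 hash of $G_x$ is used as the AES key to encrypt the flag. We are given $A$ (the "ECRSA ciphertext"), but not $G$: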
g = getrandpoint(ec,p,q)
# "ECRSA" encryption: A = e*g on the curve over Z/nZ
A = ec.multiply(g,e)
# Use key that has been shared with ECRSA
# AES-128 key = MD5 digest of the decimal string of g.x
key = md5(str(g.x).encode()).digest()
iv = os.urandom(16)
cipher = AES.new(key,AES.MODE_CBC,iv)
data = cipher.encrypt(pad(flag,16))
# Printed: ciphertext, IV, modulus and the point A (g itself stays secret)
print(f"Encrypted flag: {data.hex()}")
print(f"IV: {iv.hex()}")
print(f"N: {n}")
print(f"ECRSA Ciphertext: {A}")
Last but not least, we are given the chance to get another random point $R$ on the same curve:
print("Would you like to test the ECRSA curve?")
# Any answer other than exactly 'n' continues
if input("[y/n]> ") == 'n':
    exit()
print("Generating random point...")
# Leaks a second point R on the same curve
print(getrandpoint(ec,p,q))
print("Thanks for testing!")
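This is important to recover the curve parameters, because both $A$ and $R$ lie on the curve:

$$A_y^2 \equiv A_x^3 + a A_x + b \pmod{n}$$
$$R_y^2 \equiv R_x^3 + a R_x + b \pmod{n}$$

Subtracting both equations, we have:

$$A_y^2 - R_y^2 \equiv A_x^3 - R_x^3 + a (A_x - R_x) \pmod{n}$$

So, we can isolate $a$:

$$a \equiv \left(A_y^2 - A_x^3 - R_y^2 + R_x^3\right) \cdot (A_x - R_x)^{-1} \pmod{n}$$

And then it is trivial to find $b \equiv A_y^2 - A_x^3 - a A_x \pmod{n}$.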
Solution
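So, up to this point, we know how to factor $n$, how to compute $e$ from $p$, and how to recover $a$ and $b$.

The last thing we need is to find $G$ (actually, only $G_x$) from $A = e \cdot G$.

What we can do is apply the Chinese Remainder Theorem in a peculiar way. We can invert the scalar multiplication when the curve is defined over a prime modulus, because then the group order $\#E(\mathbb{F}_p)$ is easy to compute. As a result, we can find the inverse of $e$ modulo $\#E(\mathbb{F}_p)$ and modulo $\#E(\mathbb{F}_q)$. This gives $G_p = (e^{-1} \bmod \#E(\mathbb{F}_p)) \cdot A$ and $G_q = (e^{-1} \bmod \#E(\mathbb{F}_q)) \cdot A$. Combining their $x$-coordinates with the CRT yields $G_x \bmod n$.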
Implementation
First of all, we get all the needed information from the challenge instance:
io = get_process()


def recv_hex_line(prefix):
    """Skip to ``prefix`` and decode the hex string on the rest of that line."""
    io.recvuntil(prefix)
    return bytes.fromhex(io.recvline().strip().decode())


def recv_int_until(delim):
    """Read through ``delim`` and parse everything before it as an integer."""
    return int(io.recvuntil(delim)[:-1].decode())


def recv_point(prefix):
    """Skip to ``prefix`` (ending in 'x=') and parse ``x, y=y)`` into (x, y)."""
    io.recvuntil(prefix)
    x = recv_int_until(b',')
    io.recvuntil(b'y=')
    y = recv_int_until(b')')
    return x, y


flag_enc = recv_hex_line(b'Encrypted flag: ')
iv = recv_hex_line(b'IV: ')
io.recvuntil(b'N: ')
n = int(io.recvline().decode())
Ax, Ay = recv_point(b'ECRSA Ciphertext: Point(x=')
# Ask for the extra random point R on the same curve.
io.sendlineafter(b'[y/n]> ', b'y')
Rx, Ry = recv_point(b'Point(x=')
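Then, we factor $n$: since $p < \lfloor\sqrt{n}\rfloor < q$, we start at $\lfloor\sqrt{n}\rfloor$ and move to the next prime until we find one that divides $n$: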
# p and q are close, so isqrt(n) lies between them; walk upward through
# primes until one divides n.
q = isqrt(n)
while n % q != 0:
    q = next_prime(q)
# Order the factors explicitly: e is derived from the *smaller* prime, and the
# loop above only lands on q when isqrt(n) is strictly greater than p.
p, q = sorted((q, n // q))
io.info(f'{p = }')
io.info(f'{q = }')
At this point, we also have $p$, so we can compute $e$ exactly as the server does:
# Mirror keygen: e = next_prime(p >> (bits // 4)); the server only accepts
# keys whose modulus has exactly `bits` bits, so bits == n.bit_length().
shift = n.bit_length() // 4
e = next_prime(p >> shift)
io.info(f'{e = }')
Next, we find the curve parameters by solving the system of equations:
# A and R both satisfy y^2 = x^3 + a*x + b (mod n); subtracting the two
# equations eliminates b and leaves a linear congruence in a.
rhs_A = (pow(Ay, 2, n) - pow(Ax, 3, n)) % n  # = a*Ax + b
rhs_R = (pow(Ry, 2, n) - pow(Rx, 3, n)) % n  # = a*Rx + b
a = (rhs_A - rhs_R) * pow(Ax - Rx, -1, n) % n
b = (rhs_A - a * Ax) % n
io.info(f'{a = }')
io.info(f'{b = }')
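After that, we define the curve over $\mathbb{Z}/n\mathbb{Z}$ (with `ecc.py`, only to verify the result) and over $\mathbb{F}_p$ and $\mathbb{F}_q$ (with SageMath). Then we undo the multiplication by $e$ in each prime-field group and combine both $x$-coordinates with the CRT: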
# Curve over Z/nZ from the challenge library, used only to verify the answer.
En = ecc.EllipticCurve(a, b, n)
# SageMath curves over the prime fields, where group orders are computable.
Ep = EllipticCurve(GF(p), [a, b])
Eq = EllipticCurve(GF(q), [a, b])
Ap = Ep(Ax, Ay)
Aq = Eq(Ax, Ay)
# Undo A = e*G in each prime-field group (requires gcd(e, order) == 1).
Gp = pow(e, -1, Ep.order()) * Ap
Gq = pow(e, -1, Eq.order()) * Aq
# Sage's crt(residues, moduli): combine both x-coordinates into G_x mod n.
# (The previous call had mismatched brackets and was a syntax error.)
Gn_x = crt([int(Gp.x()), int(Gq.x())], [p, q])
# lift_x is deterministic, so re-lifting G_x reproduces the server's G.
assert En.multiply(En.lift_x(Gn_x, p, q), e) == ecc.Point(Ax, Ay)
io.success(f'{Gn_x = }')
At this point, we can derive the AES key and get the flag:
# Same derivation as the server: AES key = MD5 of the decimal string of G_x.
key = md5(str(Gn_x).encode()).digest()
cipher = AES.new(key, AES.MODE_CBC, iv)
padded_flag = cipher.decrypt(flag_enc)
flag = unpad(padded_flag, AES.block_size).decode()
io.success(flag)
Flag
If we run the script, we will get the flag:
$ python3 solve.py 94.237.63.83:33306
[+] Opening connection to 94.237.63.83 on port 33306: Done
[*] p = 106858941313638767991828235567490821200125820455933241643545899194829771009579
[*] q = 106858941313638767991828235567490821200125820455933241643545899194829771018591
[*] e = 314030204622581805680811063725039822717
[*] a = 37753997141981190002907126778194883468
[*] b = 37744197496150834695308441384696040008
[+] Gn_x = 4174411158930651296922645398385284926646467052831144880881853434751972793707369366842219778664803546408964829711287940023692621868869243079343630800011414
[+] HTB{Pr1M3_Pr0x1mIty_@nD_C0mP0s1T3_CuRv3?????=>s0_1nS3cUr3!!!}
[*] Closed connection to 94.237.63.83 port 33306
The full script can be found here: solve.py.