MSS
We are given the Python source code of the server:
```python
import os, random, json
from hashlib import sha256
from Crypto.Util.number import bytes_to_long
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from secret import FLAG


class MSS:
    def __init__(self, BITS, d, n):
        self.d = d
        self.n = n
        self.BITS = BITS
        self.key = bytes_to_long(os.urandom(BITS//8))
        self.coeffs = [self.key] + [bytes_to_long(os.urandom(self.BITS//8)) for _ in range(self.d)]

    def poly(self, x):
        return sum([self.coeffs[i] * x**i for i in range(self.d+1)])

    def get_share(self, x):
        if x < 1 or x > 2**15:
            return {'approved': 'False', 'reason': 'This scheme is intended for less users.'}
        elif self.n < 1:
            return {'approved': 'False', 'reason': 'Enough shares for today.'}
        else:
            self.n -= 1
            return {'approved': 'True', 'x': x, 'y': self.poly(x)}

    def encrypt_flag(self, m):
        key = sha256(str(self.key).encode()).digest()
        iv = os.urandom(16)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        ct = cipher.encrypt(pad(m, 16))
        return {'iv': iv.hex(), 'enc_flag': ct.hex()}


def show_banner():
    print("""
# # ##### ##### # ###
## ## # # # # ## # #
# # # # # # # # # #
# # # ##### ##### # # # # #
# # # # # # # # #
# # # # # # # # # ## # #
# # ##### ##### ## ##### ## ###
This is a secure secret sharing scheme with really small threshold. We are pretty sure the key is secure...
""")


def show_menu():
    return """
Send in JSON format any of the following commands.
- Get your share
- Encrypt flag
- Exit
query = """


def main():
    mss = MSS(256, 30, 19)
    show_banner()
    while True:
        try:
            query = json.loads(input(show_menu()))
            if 'command' in query:
                cmd = query['command']
                if cmd == 'get_share':
                    if 'x' in query:
                        x = int(query['x'])
                        share = mss.get_share(x)
                        print(json.dumps(share))
                    else:
                        print('\n[-] Please send your user ID.')
                elif cmd == 'encrypt_flag':
                    enc_flag = mss.encrypt_flag(FLAG)
                    print(f'\n[+] Here is your encrypted flag : {json.dumps(enc_flag)}.')
                elif cmd == 'exit':
                    print('\n[+] Thank you for using our service. Bye! :)')
                    break
                else:
                    print('\n[-] Unknown command:(')
        except KeyboardInterrupt:
            exit(0)
        except (ValueError, TypeError) as error:
            print(error)
            print('\n[-] Make sure your JSON query is properly formatted.')
            pass


if __name__ == '__main__':
    main()
```
Source code analysis
The challenge creates a polynomial with random coefficients, where the key is the constant term (the coefficient of $x^0$):
```python
class MSS:
    def __init__(self, BITS, d, n):
        self.d = d
        self.n = n
        self.BITS = BITS
        self.key = bytes_to_long(os.urandom(BITS//8))
        self.coeffs = [self.key] + [bytes_to_long(os.urandom(self.BITS//8)) for _ in range(self.d)]

    def poly(self, x):
        return sum([self.coeffs[i] * x**i for i in range(self.d+1)])
```
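This key is used to encrypt the flag: the SHA256 hash of its decimal representation becomes a 256-bit AES key, which encrypts the flag in CBC mode with a random IV: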
```python
    def encrypt_flag(self, m):
        key = sha256(str(self.key).encode()).digest()
        iv = os.urandom(16)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        ct = cipher.encrypt(pad(m, 16))
        return {'iv': iv.hex(), 'enc_flag': ct.hex()}
```
We are allowed to evaluate the polynomial and retrieve the result:
```python
    def get_share(self, x):
        if x < 1 or x > 2**15:
            return {'approved': 'False', 'reason': 'This scheme is intended for less users.'}
        elif self.n < 1:
            return {'approved': 'False', 'reason': 'Enough shares for today.'}
        else:
            self.n -= 1
            return {'approved': 'True', 'x': x, 'y': self.poly(x)}
```
So, the objective is to evaluate the polynomial at some points and recover the key in order to decrypt the flag.
Doing the maths
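So, the polynomial can be expressed as

$$P(x) = a_0 + a_1 x + a_2 x^2 + \cdots + a_{30} x^{30} = \sum_{i=0}^{30} a_i x^i$$

where $a_0$ is the key.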
Notice that the 31 coefficients are random numbers of up to 256 bits, and we are allowed to get at most 19 evaluations (known as shares in this Mignotte Secret Sharing scheme):
```python
mss = MSS(256, 30, 19)
```
As a result, we cannot simply interpolate the polynomial and read off its coefficients: a polynomial of degree 30 needs at least 31 shares to be uniquely determined.
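To make this concrete, write the polynomial as $P(x) = \sum_{i=0}^{30} a_i x^i$, with $a_0$ the key. Querying 19 points $x_1, \dots, x_{19}$ gives the shares $y_j = P(x_j)$, which form the following linear (Vandermonde) system in the 31 unknown coefficients:

$$
\begin{pmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^{30} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_{19} & x_{19}^2 & \cdots & x_{19}^{30}
\end{pmatrix}
\begin{pmatrix} a_0 \\ a_1 \\ \vdots \\ a_{30} \end{pmatrix}
=
\begin{pmatrix} y_1 \\ \vdots \\ y_{19} \end{pmatrix}
$$

As a sketch of the standard interpolation argument: the matrix is $19 \times 31$, so over the rationals the solutions form a 12-dimensional family, and adding any multiple of $(x - x_1)(x - x_2) \cdots (x - x_{19})$ to $P$ changes $a_0$ without changing a single share.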
Solution
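The problem of this challenge is that the polynomial is defined over the integers, not over a finite field. As a result, if we evaluate $P(x)$ at some integer $x$, every term except the constant term $a_0$ (the key) is a multiple of $x$:

$$P(x) = a_0 + x \left(a_1 + a_2 x + \cdots + a_{30} x^{29}\right) \equiv a_0 \pmod{x}$$

Knowing this, we can do the same with distinct prime numbers $p_1, \dots, p_{19}$. The objective is to get this system of congruences:

$$
\begin{cases}
a_0 \equiv P(p_1) \pmod{p_1} \\
a_0 \equiv P(p_2) \pmod{p_2} \\
\quad \vdots \\
a_0 \equiv P(p_{19}) \pmod{p_{19}}
\end{cases}
$$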
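This system of modular congruences can be solved using the Chinese Remainder Theorem (CRT), which will output the value of $a_0 \bmod (p_1 p_2 \cdots p_{19})$.

All these prime numbers must be less than $2^{15}$, because the server rejects any $x > 2^{15}$, and they must be distinct so that the moduli are pairwise coprime. If we take 15-bit primes (between $2^{14}$ and $2^{15}$), their product is greater than $(2^{14})^{19} = 2^{266}$, which exceeds any 256-bit key, so the value returned by the CRT is exactly $a_0$.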
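The reasoning above can be checked offline. The following is a minimal, stdlib-only sketch (not the author's solve.py) that rebuilds a polynomial with the same shape as `MSS(256, 30, 19)` locally, evaluates it at 19 distinct 15-bit primes and recombines the residues with a hand-written CRT. The helper names `primes_in_range`, `crt` and `simulate_attack` are hypothetical, and taking the 19 largest primes below $2^{15}$ is an arbitrary choice: any 19 distinct primes in $[2^{14}, 2^{15})$ would do.

```python
import math
import os

BITS, D, N = 256, 30, 19  # same parameters as MSS(256, 30, 19)


def primes_in_range(lo, hi):
    """Return all primes p with lo <= p < hi, using a simple sieve."""
    sieve = bytearray([1]) * hi
    sieve[0:2] = b'\x00\x00'
    for i in range(2, math.isqrt(hi) + 1):
        if sieve[i]:
            sieve[i * i::i] = bytearray(len(range(i * i, hi, i)))
    return [p for p in range(lo, hi) if sieve[p]]


def crt(moduli, residues):
    """Solve x = r_i (mod m_i) for pairwise coprime m_i; return (x, M)."""
    x, M = 0, 1
    for m, r in zip(moduli, residues):
        # Pick t so that x + M*t = r (mod m); earlier congruences still hold
        t = ((r - x) * pow(M, -1, m)) % m
        x, M = x + M * t, M * m
    return x, M


def simulate_attack():
    """Run the attack against a local copy of the polynomial (no server)."""
    key = int.from_bytes(os.urandom(BITS // 8), 'big')
    coeffs = [key] + [int.from_bytes(os.urandom(BITS // 8), 'big') for _ in range(D)]

    def poly(x):
        # Evaluated over the integers, exactly like MSS.poly
        return sum(c * x**i for i, c in enumerate(coeffs))

    # 19 distinct 15-bit primes, all accepted by the 1 <= x <= 2**15 check
    primes = primes_in_range(2**14, 2**15)[-N:]
    residues = [poly(p) % p for p in primes]  # P(p) mod p == key mod p
    recovered, M = crt(primes, residues)
    return key, recovered, M


key, recovered, M = simulate_attack()
print(recovered == key, M.bit_length())
```

Using fixed primes keeps the check deterministic; against the real server, randomly chosen 15-bit primes work just as well as long as none repeats.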
Implementation
We can program the above procedure in Python:
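```python
import json, sys
from hashlib import sha256
from Crypto.Cipher import AES
from Crypto.Util.number import getPrime
from Crypto.Util.Padding import unpad
from pwn import remote
from sympy.ntheory.modular import crt

def get_process():
    # Assumed helper: connect to the host:port given on the command line
    host, port = sys.argv[1].split(':')
    return remote(host, int(port))

io = get_process()

# 19 *distinct* 15-bit primes (getPrime may return repeats, so use a set)
primes = set()
while len(primes) < 19:
    primes.add(getPrime(15))
primes, remainders = sorted(primes), []
for p in primes:
    io.sendlineafter(b'query = ', json.dumps({'command': 'get_share', 'x': p}).encode())
    r = json.loads(io.recvline().decode()).get('y')
    remainders.append(r % p)  # P(p) mod p == key mod p

# The product of the primes exceeds 2^256, so the CRT returns the key itself
key = int(crt(primes, remainders)[0])

io.sendlineafter(b'query = ', json.dumps({'command': 'encrypt_flag'}).encode())
io.recvuntil(b'[+] Here is your encrypted flag : ')
data = json.loads(io.recvuntil(b'}').decode())
iv = bytes.fromhex(data.get('iv'))
enc_flag = bytes.fromhex(data.get('enc_flag'))

aes_key = sha256(str(key).encode()).digest()  # same derivation as the server
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
flag = unpad(cipher.decrypt(enc_flag), AES.block_size).decode()
io.success(flag)
```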
Flag
And here we have the flag:
```
$ python3 solve.py 94.237.54.170:55965
[+] Opening connection to 94.237.54.170 on port 55965: Done
[+] HTB{sm4ll_thr3sh0ld_n0_pr0bl3m_CRT_ru13s!}
[*] Closed connection to 94.237.54.170 port 55965
```
The full script can be found here: solve.py