baby quick maffs
We are given the Python script used to encrypt the flag, followed by the output.txt file it produced:
#!/usr/bin/env python3
from secret import flag, p, q
from Crypto.Util.number import bytes_to_long
from random import randint

def partition_message(m, N):
    m1 = randint(1, N)
    parts = []
    remainder = 0
    while sum(parts) < m:
        if sum(parts) + m1 < m:
            parts.append(m1)
        else:
            remainder = m - sum(parts)
            parts.append(m1 + remainder)
    return (parts, remainder)

def encode(message, N):
    m = bytes_to_long(message)
    parts, remainder = partition_message(m, N)
    ciphers = [pow(c, 2, N) for c in parts]
    return (ciphers, remainder)

N = p * q
ciphers, remainder = encode(flag, N)
with open("output.txt", "w") as f:
    out = f'{N}\n{remainder}\n{ciphers}'
    f.write(out)
6083782486455360611313889289556658208725888944237734041722591252756006664878102248734673207367745303402874595854966731263105387801996693270011840173939423
1081087287982224274239399953615475281184099226198643053396569433856757255106426461817760194704250226883807897800355728788149068771546876055268915238961343
[5408283916250636369066846815501131861319520431106165986129813106223074286810632222888292034380612581416458756909119954039579666773680866532576166358987272, 5408283916250636369066846815501131861319520431106165986129813106223074286810632222888292034380612581416458756909119954039579666773680866532576166358987272, 5598555010250184271123226314796180406367795504188162611960100902143581636125416986623404842897202277277978566659455918773104687212096435095590205751904580]
Analyzing the encryption mechanism
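Let's see what we have. The encryption draws a random number $m_1$ from $[1, N]$ and splits the message $m$ into $k$ parts. The first $k - 1$ parts are all equal to $m_1$, and the last part is $m_1 + r$, where $r = m - (k - 1)m_1$ is the value stored as remainder in the code. Each part $x_i$ is encrypted as:
$c_i = x_i^2 \bmod N$
where $N$ is the modulus on the first line of output.txt. Notice that output.txt contains three ciphertexts and the first two are identical, so $k = 3$ and the parts are $m_1$, $m_1$ and $m_1 + r$.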
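To see the partition concretely, here is a minimal sketch that mirrors partition_message but takes $m_1$ as a parameter instead of drawing it with randint. The helper name partition_with and the small numbers are illustrative assumptions, not part of the challenge.

```python
def partition_with(m, m1):
    """Mirror of the challenge's partition_message with a fixed m1 (hypothetical helper)."""
    parts = []
    remainder = 0
    while sum(parts) < m:
        if sum(parts) + m1 < m:
            parts.append(m1)
        else:
            # The last part absorbs whatever is left of m, on top of m1
            remainder = m - sum(parts)
            parts.append(m1 + remainder)
    return parts, remainder


# With m = 100 and m1 = 40, two full copies of m1 fit before the last part
parts, r = partition_with(100, 40)
print(parts, r)  # [40, 40, 60] 20
```

Note that the parts sum to $m + m_1$, not $m$. The message is therefore $(k - 1)m_1 + r$, which for $k = 3$ is $2m_1 + r$.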
Finding the flaw
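The plaintexts are related by a known offset $r$, so we never need to take a modular square root. We have these values:
$a = m_1^2 \bmod N, \quad b = m_1^2 \bmod N, \quad c = (m_1 + r)^2 \bmod N$
Notice that $a = b$, and that expanding the square in $c$ gives
$c \equiv m_1^2 + 2m_1 r + r^2 \equiv a + 2m_1 r + r^2 \pmod{N}$
Hence,
$2m_1 r \equiv c - a - r^2 \pmod{N}$
And thus, since $r$ is known and $2r$ is invertible modulo $N$ (unless $r$ happens to share a factor with $N$),
$m_1 \equiv (2r)^{-1}(c - a - r^2) \pmod{N}$
Now we have $m_1$. Since $m_1$ was drawn from $[1, N]$, this residue is $m_1$ itself (except in the negligible case $m_1 = N$), so the message is $m = 2m_1 + r$.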
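The derivation can be checked on toy numbers. The sketch below assumes the small primes $p = 1009$ and $q = 1013$ and made-up values for $m_1$ and $r$. It builds the same three-part layout, encrypts it, and then recovers $m$ using only $N$, $r$ and the ciphertexts. It needs Python 3.8+ for pow(x, -1, N).

```python
def recover(N, r, a, c):
    """Recover m from a = m1^2 mod N and c = (m1 + r)^2 mod N, with r known."""
    # c - a - r^2 = 2*m1*r (mod N), so multiply by the inverse of 2r
    m1 = pow(2 * r, -1, N) * (c - a - r * r) % N
    return 2 * m1 + r


p, q = 1009, 1013          # toy primes (assumption; far too small for real use)
N = p * q
m1 = 123456                # stands in for randint(1, N)
r = 4321                   # gives parts m1, m1, m1 + r
m = 2 * m1 + r             # the toy "message"
a = pow(m1, 2, N)          # first (and second) ciphertext
c = pow(m1 + r, 2, N)      # third ciphertext
print(recover(N, r, a, c) == m)  # True
```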
Python implementation
Here is a Python script to solve the challenge:
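#!/usr/bin/env python3
from ast import literal_eval


def main():
    with open('output.txt') as f:
        N = int(f.readline())
        remainder = int(f.readline())
        # The ciphertext list is a Python literal; parse it without eval()
        a, b, c = literal_eval(f.readline())
    # c - a - r^2 = 2*m1*r (mod N), so divide by 2r with a modular inverse
    m1 = pow(2 * remainder, -1, N) * (c - a - remainder ** 2) % N
    # The parts were m1, m1 and m1 + remainder, so m = 2*m1 + remainder
    m = m1 + m1 + remainder
    # to_bytes avoids the odd-length-hex pitfall of bytes.fromhex(hex(m)[2:])
    print(m.to_bytes((m.bit_length() + 7) // 8, 'big').decode())


if __name__ == '__main__':
    main()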
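The solver relies on pow(x, -1, N), which computes modular inverses only from Python 3.8 onward. On older interpreters, a minimal extended-Euclid replacement such as the hypothetical mod_inverse below could be used instead. It raises ValueError when the inverse does not exist, which here would mean 2 * remainder shares a factor with N.

```python
def mod_inverse(x, n):
    """Inverse of x modulo n via the extended Euclidean algorithm (hypothetical helper)."""
    old_r, r = x % n, n
    old_s, s = 1, 0
    # Invariant: old_r = old_s * x (mod n) and r = s * x (mod n)
    while r:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_s, s = s, old_s - quotient * s
    if old_r != 1:
        raise ValueError("x is not invertible modulo n")
    return old_s % n


print(mod_inverse(3, 7))  # 5, since 3 * 5 = 15 = 2 * 7 + 1
```

In the solver, pow(2 * remainder, -1, N) would then become mod_inverse(2 * remainder, N).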
Flag
$ python3 solve.py
HTB{d0nt_ev4_r3l4ted_m3ss4ge_att4cks_th3y_ar3_@_d3a1_b7eak3r!!!}
The full script can be found here: solve.py.