baby quick maffs
We are given the Python script that encrypts the flag, together with the output.txt file it produced:
```python
#!/usr/bin/env python3
from secret import flag, p, q
from Crypto.Util.number import bytes_to_long
from random import randint


def partition_message(m, N):
    m1 = randint(1, N)
    parts = []
    remainder = 0
    while sum(parts) < m:
        if sum(parts) + m1 < m:
            parts.append(m1)
        else:
            remainder = m - sum(parts)
            parts.append(m1 + remainder)
    return (parts, remainder)


def encode(message, N):
    m = bytes_to_long(message)
    parts, remainder = partition_message(m, N)
    ciphers = [pow(c, 2, N) for c in parts]
    return (ciphers, remainder)


N = p * q
ciphers, remainder = encode(flag, N)
with open("output.txt", "w") as f:
    out = f'{N}\n{remainder}\n{ciphers}'
    f.write(out)
```
```text
6083782486455360611313889289556658208725888944237734041722591252756006664878102248734673207367745303402874595854966731263105387801996693270011840173939423
1081087287982224274239399953615475281184099226198643053396569433856757255106426461817760194704250226883807897800355728788149068771546876055268915238961343
[5408283916250636369066846815501131861319520431106165986129813106223074286810632222888292034380612581416458756909119954039579666773680866532576166358987272, 5408283916250636369066846815501131861319520431106165986129813106223074286810632222888292034380612581416458756909119954039579666773680866532576166358987272, 5598555010250184271123226314796180406367795504188162611960100902143581636125416986623404842897202277277978566659455918773104687212096435095590205751904580]
```
Analyzing the encryption mechanism
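Let’s see what we have. The encryption picks a random number $m_1$ with $1 \leqslant m_1 \leqslant N$, where $N = p q$. It then converts the flag to an integer $m$ with bytes_to_long and keeps appending $m_1$ to a list of parts as long as the running sum plus $m_1$ stays below $m$. When it no longer does, the code sets $r = m - (\text{sum so far})$ (the variable remainder) and appends $m_1 + r$ as the final part. Since the output list has three elements, the parts are $m_1$, $m_1$ and $m_1 + r$, so $m = m_1 + m_1 + r$. Each part is then squared modulo $N$: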
$$ a = m_1^2 \mod{N} $$
$$ b = m_1^2 \mod{N} $$
$$ c = (m_1 + r)^2 \mod{N} $$
where $a, b, c$ are the three values in the list in output.txt. Notice that $a = b$, as expected, since both are squares of the same $m_1$.
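To see this structure concretely, here is a minimal local replica of partition_message and encode. It is a sketch under stated assumptions, not the challenge code: the toy modulus is the product of the Mersenne primes $2^{89} - 1$ and $2^{127} - 1$ (not the secret $p$ and $q$), partition_message takes $m_1$ as an argument instead of drawing it with randint so the run is reproducible, and encode also returns the parts for inspection. Choosing $m/3 \leqslant m_1 < m/2$ yields exactly the three parts seen in output.txt.

```python
# Local replica of the challenge's partitioning, for experimentation only.
# Assumptions (not from the challenge): toy primes, and m1 passed in explicitly.
P_TOY = 2**89 - 1   # Mersenne prime
Q_TOY = 2**127 - 1  # Mersenne prime
N_TOY = P_TOY * Q_TOY


def partition_message(m, m1):
    """Same loop as the challenge, but m1 is supplied by the caller."""
    parts = []
    remainder = 0
    while sum(parts) < m:
        if sum(parts) + m1 < m:
            parts.append(m1)
        else:
            remainder = m - sum(parts)
            parts.append(m1 + remainder)
    return parts, remainder


def encode(m, m1, N):
    """Square every part modulo N; also return the parts (a tweak for inspection)."""
    parts, remainder = partition_message(m, m1)
    ciphers = [pow(x, 2, N) for x in parts]
    return ciphers, remainder, parts


m = int.from_bytes(b"HTB{toy_flag}", "big")
m1 = m // 3 + 12345  # any value with m/3 <= m1 < m/2 gives three parts
ciphers, r, parts = encode(m, m1, N_TOY)

print(parts == [m1, m1, m1 + r])  # True
print(m == m1 + m1 + r)           # True
print(ciphers[0] == ciphers[1])   # True: a == b
```

With $m_1 < m/3$ the loop would emit more copies of $m_1$, and with $m_1 \geqslant m/2$ fewer, so the three-element list in output.txt tells us which case the challenge hit.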
Finding the flaw
We have these values: $N$, $r$ and $[a, b, c]$.
Notice that
$$ c = (m_1 + r)^2 = m_1^2 + 2 m_1 r + r^2 = a + 2 m_1 r + r^2 \mod{N} $$
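The expansion is ordinary algebra, but a quick numeric check makes the key point tangible: modulo $N$, the only unknown separating $c$ from $a + r^2$ is the cross term $2 m_1 r$. The values below are arbitrary toy numbers assumed for illustration, unrelated to the challenge's output.

```python
# Numeric sanity check of c = a + 2*m1*r + r^2 (mod N) on toy values.
N = (2**61 - 1) * (2**89 - 1)     # toy modulus: product of two Mersenne primes
m1 = 0x1234_5678_9ABC_DEF0_1234   # arbitrary toy m1 < N
r = 0xDEAD_BEEF_CAFE              # arbitrary toy remainder

a = pow(m1, 2, N)        # first (and second) cipher
c = pow(m1 + r, 2, N)    # third cipher

print(c == (a + 2 * m1 * r + r * r) % N)         # True
print((c - a - r * r) % N == (2 * m1 * r) % N)   # True
```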
Hence,
$$ 2 m_1 r = c - a - r^2 \mod{N} $$
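Since $N$ is odd and, except with negligible probability, $\gcd(r, N) = 1$, the value $2r$ is invertible modulo $N$. Thus,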
$$ m_1 = (2 r)^{-1} \cdot (c - a - r^2) \mod{N} $$
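Once we have $m_1$, we recover the flag as $m = m_1 + m_1 + r$. This holds over the integers because $m_1 < N$ (barring the negligible case $m_1 = N$), so the value computed modulo $N$ is $m_1$ itself.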
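The two steps above (solve for $m_1$, then add) fit in a few lines. This is a sketch, not the author's solver: recover_flag_int is a hypothetical helper name, and the synthetic instance uses toy Mersenne primes and a placeholder flag. It checks $\gcd(2r, N) = 1$ explicitly, because the modular inverse exists only in that case (a shared factor would instead reveal $p$ or $q$).

```python
from math import gcd


def recover_flag_int(N, r, a, c):
    """Hypothetical helper: recover m = 2*m1 + r from N, r and ciphers a, c."""
    if gcd(2 * r, N) != 1:
        raise ValueError("2r is not invertible mod N")
    # m1 = (2r)^-1 * (c - a - r^2) mod N   (pow(x, -1, N) needs Python 3.8+)
    m1 = pow(2 * r, -1, N) * (c - a - r * r) % N
    return m1 + m1 + r


# Synthetic instance built like the challenge's (toy primes assumed).
N = (2**89 - 1) * (2**127 - 1)
flag = b"HTB{not_the_real_flag}"
m = int.from_bytes(flag, "big")
m1 = m // 3 + 1          # lands in [m/3, m/2), so there are three parts
r = m - 2 * m1
a, c = pow(m1, 2, N), pow(m1 + r, 2, N)

m_rec = recover_flag_int(N, r, a, c)
print(m_rec.to_bytes((m_rec.bit_length() + 7) // 8, "big"))  # b'HTB{not_the_real_flag}'
```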
Python implementation
Here is a Python script to solve the challenge:
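```python
#!/usr/bin/env python3
from ast import literal_eval


def main():
    with open('output.txt') as f:
        N = int(f.readline())
        remainder = int(f.readline())
        a, b, c = literal_eval(f.readline())  # parse the list as data, not code

    # m1 = (2r)^-1 * (c - a - r^2) mod N  (pow(x, -1, N) needs Python 3.8+)
    m1 = pow(2 * remainder, -1, N) * (c - a - remainder ** 2) % N
    m = m1 + m1 + remainder

    # Integer -> bytes; unlike bytes.fromhex(hex(m)[2:]), this also handles odd-length hex
    print(m.to_bytes((m.bit_length() + 7) // 8, 'big').decode())


if __name__ == '__main__':
    main()
```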
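To exercise the parsing and arithmetic without the challenge files, one can write a synthetic output.txt in the same layout the challenge uses (f'{N}\n{remainder}\n{ciphers}') and read it back. This is an illustrative harness under assumptions: the toy primes, the placeholder flag and the function names write_output and solve_file are hypothetical, and ast.literal_eval is used so the list line is parsed as data rather than executed.

```python
import ast
import os
import tempfile


def write_output(path, N, remainder, ciphers):
    # Hypothetical helper: same layout as the challenge (N, remainder, list repr).
    with open(path, "w") as f:
        f.write(f"{N}\n{remainder}\n{ciphers}")


def solve_file(path):
    # Hypothetical helper: same steps as solve.py, returning the flag bytes.
    with open(path) as f:
        N = int(f.readline())
        remainder = int(f.readline())
        a, b, c = ast.literal_eval(f.readline())
    m1 = pow(2 * remainder, -1, N) * (c - a - remainder ** 2) % N
    m = m1 + m1 + remainder
    return m.to_bytes((m.bit_length() + 7) // 8, "big")


N = (2**89 - 1) * (2**127 - 1)        # toy modulus, not the challenge's
m = int.from_bytes(b"HTB{round_trip}", "big")
m1 = m // 3 + 7                        # three parts: m1, m1, m1 + r
r = m - 2 * m1
ciphers = [pow(m1, 2, N), pow(m1, 2, N), pow(m1 + r, 2, N)]

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "output.txt")
    write_output(path, N, r, ciphers)
    recovered = solve_file(path)

print(recovered)  # b'HTB{round_trip}'
```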
Flag
```console
$ python3 solve.py
HTB{d0nt_ev4_r3l4ted_m3ss4ge_att4cks_th3y_ar3_@_d3a1_b7eak3r!!!}
```
The full script can be found here: solve.py.