Easy DSA: The beginning
We are given the Python source code used to encrypt the flag:
```python
from Crypto.Util.Padding import pad
from Crypto.Util.number import isPrime, getPrime, long_to_bytes
from Crypto.Cipher import AES
from hashlib import sha256
from random import randrange

def gen_key():
    p = 0
    while not isPrime(p):
        q = getPrime(300)
        p = 2*q + 1
    g = randrange(2, p)**2 % p
    k = randrange(2, q)
    x = randrange(2, q)
    y = pow(g, x, p)
    return p, q, g, x, y, k

def H(msg):
    return int.from_bytes(sha256(msg).digest(), 'big')

def sign(m):
    r = pow(g, k, p) % q
    s = (H(m) + x*r) * pow(k, -1, q) % q
    return r, s

def verify(m, r, s):
    assert 0 < r < q and 0 < s < q
    u = pow(s, -1, q)
    v = pow(g, H(m) * u, p) * pow(y, r * u, p) % p % q
    return v == r

flag = b"ictf{REDACTED}"
p, q, g, x, y, k = gen_key()
ms = b"jctf{powered_by_caffeine}", b"jctf{totally_real_flag}"
sigs = [sign(m) for m in ms]
assert all(verify(m, *sig) for m, sig in zip(ms, sigs))
aes = AES.new(long_to_bytes(x)[:16], AES.MODE_CBC, b'\0'*16)
c = aes.encrypt(pad(flag, 16)).hex()
print(f'{p = }\n{g = }\n{y = }\n{ms = }\n{sigs = }\n{c = }')
```
And the output of the above script:
```
p = 2187927460624367866053138955407692648682473743053236246707558906253741042480006602164664427
g = 375559713231366027661456501312210678588547344177468345614581736759352578046940519482449005
y = 1485107098193369513854775432342726913250546508148678604594096036026212707003506931382110518
ms = (b'jctf{powered_by_caffeine}', b'jctf{totally_real_flag}')
sigs = [(584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 43566017108108194938809536454030127793515021629016835721136006757000695802735201944729583), (584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 587754055422977160798386807229397695762555861352726788417293238718373985110611538922057038)]
c = '614585db552484e4c81c4168afa8582bd975bfadd5edc8a4d1bf744c29a7d84f30cde5fe4b37f736af3f09480bcb626a'
```
Source code analysis
The script implements the Digital Signature Algorithm (DSA) and uses the first 16 bytes of the private key $x$ as the AES key that encrypts the flag:
```python
aes = AES.new(long_to_bytes(x)[:16], AES.MODE_CBC, b'\0'*16)
c = aes.encrypt(pad(flag, 16)).hex()
```
Therefore, we need to recover $x$ from the DSA signatures.
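The key derivation is easy to overlook: `long_to_bytes(x)[:16]` keeps only the 16 most significant bytes of the big-endian encoding of $x$ (a number of about 300 bits, so up to 38 bytes). A minimal standard-library sketch of that step, where `aes_key_from_x` is a hypothetical helper and the encoding is assumed to match `long_to_bytes` for positive integers (no leading zero bytes):

```python
def aes_key_from_x(x: int) -> bytes:
    # Big-endian encoding of x without leading zero bytes (at least one byte),
    # assumed equivalent to Crypto.Util.number.long_to_bytes for x > 0.
    raw = x.to_bytes(max(1, (x.bit_length() + 7) // 8), "big")
    # Only the 16 most significant bytes become the AES-128 key.
    return raw[:16]


# A 38-byte value 0x0102...26: the key is its first 16 bytes.
x_demo = int.from_bytes(bytes(range(1, 39)), "big")
print(aes_key_from_x(x_demo).hex())
```

Since $x$ is drawn from $[2, q)$, recovering it modulo $q$ gives its exact value, so the derived key matches the one used for encryption.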
DSA implementation
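The DSA arithmetic itself is implemented correctly (more information here):
- It generates primes $p$ and $q$ such that $p = 2q + 1$
- Then the public values $g$ (a square modulo $p$, so it lies in the subgroup of order $q$) and $y = g^x \mod{p}$
- And the private values $x$ (the signing key) and $k$ (the nonce)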
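The key-generation steps above can be reproduced with toy numbers to check the structure. This is a minimal sketch assuming the tiny safe prime $p = 2039 = 2 \cdot 1019 + 1$ (chosen only for readability; real parameters are hundreds of bits) and a hypothetical helper `toy_gen_key`. Squaring a random $h \neq \pm 1$ gives $g \neq 1$ in the subgroup of squares, which has prime order $q$, so $g^q \equiv 1 \pmod{p}$:

```python
from random import randrange


def is_prime(n: int) -> bool:
    # Trial division: only suitable for the tiny toy numbers used here.
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def toy_gen_key(p: int = 2039, q: int = 1019):
    # Toy safe prime p = 2q + 1 with both p and q prime (far too small for real use).
    assert is_prime(q) and p == 2 * q + 1 and is_prime(p)
    # h is never 1 or p - 1, so g = h^2 mod p is a square different from 1.
    h = randrange(2, p - 1)
    g = h * h % p
    x = randrange(2, q)   # private signing key
    y = pow(g, x, p)      # public key
    k = randrange(2, q)   # nonce; generating it once per key is the challenge's bug
    return p, q, g, x, y, k


p, q, g, x, y, k = toy_gen_key()
assert g != 1 and pow(g, q, p) == 1  # g has order exactly q
```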
To sign a message $m$, it uses SHA-256 as the hash function (reading the digest as a big-endian integer) and computes:
$$ r = \left(g^k \mod{p}\right) \mod{q} $$
$$ s = \left(\mathrm{SHA256}{(m)} + x \cdot r\right) \cdot k^{-1} \mod{q} $$
And the output is the pair $(r, s)$.
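A toy run of the signing equations, assuming the tiny group $p = 2039$, $q = 1019$, $g = 25 = 5^2$, a hypothetical private key $x = 500$, and small integers standing in for the SHA-256 digests. It also mirrors the challenge's `verify`, which recovers $g^k$ because $\left(\mathrm{SHA256}{(m)} + x \cdot r\right) \cdot s^{-1} \equiv k \pmod{q}$:

```python
# Toy group (assumption): tiny safe prime p = 2q + 1, far too small for real use.
p, q = 2039, 1019
g = 25               # 5**2 mod p: a square, so it has order q
x = 500              # hypothetical private key
y = pow(g, x, p)     # public key


def sign(h: int, k: int):
    # h stands in for SHA256(m) as an integer; note r is reduced mod p, then mod q.
    r = pow(g, k, p) % q
    s = (h + x * r) * pow(k, -1, q) % q
    return r, s


def verify(h: int, r: int, s: int) -> bool:
    if not (0 < r < q and 0 < s < q):
        return False
    u = pow(s, -1, q)
    # g^(h*u) * y^(r*u) = g^((h + x*r) * s^-1) = g^k (mod p)
    v = pow(g, h * u % q, p) * pow(y, r * u % q, p) % p % q
    return v == r


print(sign(100, 3))
```

For example, `sign(100, 3)` returns `(333, 847)`: $25^3 = 15625 \equiv 1352 \pmod{2039}$ and $1352 \mod{1019} = 333$.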
The security flaw
The problem is that two different messages are signed with the same nonce $k$: `gen_key` generates $k$ once, and `sign` reuses it for every message. In DSA, a nonce must be used only once; otherwise, an attacker can recover the private key $x$ and forge signatures on arbitrary messages.
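So we have message $m_1$ with signature $(r_1, s_1)$ and message $m_2$ with signature $(r_2, s_2)$. Notice that $r_1 = r_2 = r$, because $r = \left(g^k \mod{p}\right) \mod{q}$ depends only on $k$, which is the same for both signatures (the two `sigs` above indeed share their first component).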
Therefore, we have
$$ \begin{cases} s_1 = \left(\mathrm{SHA256}{(m_1)} + x \cdot r\right) \cdot k^{-1} \mod{q} \\ s_2 = \left(\mathrm{SHA256}{(m_2)} + x \cdot r\right) \cdot k^{-1} \mod{q} \end{cases} $$
Subtracting the second equation from the first, the $x \cdot r$ terms cancel and we get
$$ s_1 - s_2 = \left(\mathrm{SHA256}{(m_1)} - \mathrm{SHA256}{(m_2)}\right) \cdot k^{-1} \mod{q} $$
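Multiplying both sides by $k$ gives an equivalent form that yields $k$ without the intermediate $k^{-1}$. It needs $s_1 \neq s_2 \pmod{q}$, which holds here because the two SHA-256 digests differ and both are smaller than the 300-bit prime $q$:
$$ (s_1 - s_2) \cdot k = \mathrm{SHA256}{(m_1)} - \mathrm{SHA256}{(m_2)} \mod{q} \quad\Longrightarrow\quad k = \left(\mathrm{SHA256}{(m_1)} - \mathrm{SHA256}{(m_2)}\right) \cdot (s_1 - s_2)^{-1} \mod{q} $$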
So we can find $k^{-1}$ as follows
$$ k^{-1} = (s_1 - s_2) \cdot \left(\mathrm{SHA256}{(m_1)} - \mathrm{SHA256}{(m_2)}\right)^{-1} \mod{q} $$
Then $k = \left(k^{-1}\right)^{-1} \mod{q}$. Once we have $k$, rearranging the first signature equation as $s_1 \cdot k = \mathrm{SHA256}{(m_1)} + x \cdot r \mod{q}$ lets us solve for $x$:
$$ x = \left(s_1 \cdot k - \mathrm{SHA256}{(m_1)}\right) \cdot r^{-1} \mod{q} $$
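The whole attack can be checked end to end on a toy group. This is a sketch under stated assumptions: tiny $p = 2039$, $q = 1019$, $g = 25$, hypothetical secrets $x = 500$ and $k = 3$, and the integers $100$ and $200$ standing in for the two SHA-256 digests. `recover` is a hypothetical helper that applies exactly the formulas above:

```python
# Toy parameters (assumption): tiny safe prime p = 2q + 1 and g of order q.
p, q, g = 2039, 1019, 25
x_secret, k_secret = 500, 3   # hypothetical secrets, unknown to the attacker
h1, h2 = 100, 200             # stand-ins for SHA256(m1) and SHA256(m2)


def sign(h: int):
    # The same nonce is used for every message: this is the flaw.
    r = pow(g, k_secret, p) % q
    s = (h + x_secret * r) * pow(k_secret, -1, q) % q
    return r, s


def recover(h1: int, sig1, h2: int, sig2, q: int):
    (r1, s1), (r2, s2) = sig1, sig2
    assert r1 == r2, "different r values: the nonce was not reused"
    r = r1
    k_inv = (s1 - s2) * pow(h1 - h2, -1, q) % q   # k^-1 from the difference
    k = pow(k_inv, -1, q)
    x = (s1 * k - h1) * pow(r, -1, q) % q         # x from the first signature
    return k, x


k, x = recover(h1, sign(h1), h2, sign(h2), q)
assert (k, x) == (k_secret, x_secret)
```

The recovery needs $\mathrm{SHA256}{(m_1)} - \mathrm{SHA256}{(m_2)}$ and $r$ to be invertible modulo $q$; since $q$ is prime, that only requires both to be nonzero modulo $q$.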
Flag
If we reproduce the above calculations in Python, we find $x$ and can therefore decrypt the AES ciphertext to get the flag:
```
$ python3 -q
>>> from Crypto.Cipher import AES
>>> from Crypto.Util.number import long_to_bytes
>>> from Crypto.Util.Padding import unpad
>>> from hashlib import sha256
>>>
>>> def H(msg):
...     return int.from_bytes(sha256(msg).digest(), 'big')
...
>>>
>>> p = 2187927460624367866053138955407692648682473743053236246707558906253741042480006602164664427
>>> g = 375559713231366027661456501312210678588547344177468345614581736759352578046940519482449005
>>> y = 1485107098193369513854775432342726913250546508148678604594096036026212707003506931382110518
>>> ms = (b'jctf{powered_by_caffeine}', b'jctf{totally_real_flag}')
>>> sigs = [(584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 43566017108108194938809536454030127793515021629016835721136006757000695802735201944729583), (584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 587754055422977160798386807229397695762555861352726788417293238718373985110611538922057038)]
>>> c = '614585db552484e4c81c4168afa8582bd975bfadd5edc8a4d1bf744c29a7d84f30cde5fe4b37f736af3f09480bcb626a'
>>>
>>> q = (p - 1) // 2
>>> r = sigs[0][0]
>>> s1, s2 = sigs[0][1], sigs[1][1]
>>> H1, H2 = H(ms[0]), H(ms[1])
>>>
>>> k_inv = (s1 - s2) * pow(H1 - H2, -1, q) % q
>>> k = pow(k_inv, -1, q)
>>> x = (s1 * k - H1) * pow(r, -1, q) % q
>>>
>>> aes = AES.new(long_to_bytes(x)[:16], AES.MODE_CBC, b'\0'*16)
>>> unpad(aes.decrypt(bytes.fromhex(c)), 16)
b'ictf{4_n0nc3_5h0u!d_0n!y_b3_u53d_0NC3!!!}'
```
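As a sanity check that does not need AES, the recovered secrets can be tested against the public values: $y = g^x \mod{p}$ and $r = \left(g^k \mod{p}\right) \mod{q}$ must both hold. A minimal standard-library sketch reusing the challenge output above and the same formulas as the session:

```python
from hashlib import sha256


def H(msg: bytes) -> int:
    # SHA-256 digest read as a big-endian integer, as in the challenge
    return int.from_bytes(sha256(msg).digest(), "big")


# Public values copied from the challenge output
p = 2187927460624367866053138955407692648682473743053236246707558906253741042480006602164664427
g = 375559713231366027661456501312210678588547344177468345614581736759352578046940519482449005
y = 1485107098193369513854775432342726913250546508148678604594096036026212707003506931382110518
m1, m2 = b'jctf{powered_by_caffeine}', b'jctf{totally_real_flag}'
r = 584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242
s1 = 43566017108108194938809536454030127793515021629016835721136006757000695802735201944729583
s2 = 587754055422977160798386807229397695762555861352726788417293238718373985110611538922057038

q = (p - 1) // 2  # p = 2q + 1 by construction
H1, H2 = H(m1), H(m2)

# Same formulas as in the session: k^-1 from the difference, then k, then x
k = pow((s1 - s2) * pow(H1 - H2, -1, q) % q, -1, q)
x = (s1 * k - H1) * pow(r, -1, q) % q

# The recovered secrets must reproduce the public key and the shared r
assert pow(g, x, p) == y
assert pow(g, k, p) % q == r
print("x =", x)
```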