Easy DSA: The beginning
We are given the Python source code used to encrypt the flag:
```python
from Crypto.Util.Padding import pad
from Crypto.Util.number import isPrime, getPrime, long_to_bytes
from Crypto.Cipher import AES
from hashlib import sha256
from random import randrange

def gen_key():
    p = 0
    while not isPrime(p):
        q = getPrime(300)
        p = 2*q + 1
    g = randrange(2, p)**2 % p
    k = randrange(2, q)  # the nonce is drawn once here, alongside the keys
    x = randrange(2, q)
    y = pow(g, x, p)
    return p, q, g, x, y, k

def H(msg):
    return int.from_bytes(sha256(msg).digest(), 'big')

def sign(m):
    # uses the global k, so every signature shares the same nonce (and the same r)
    r = pow(g, k, p) % q
    s = (H(m) + x*r) * pow(k, -1, q) % q
    return r, s

def verify(m, r, s):
    assert 0 < r < q and 0 < s < q
    u = pow(s, -1, q)
    v = pow(g, H(m) * u, p) * pow(y, r * u, p) % p % q
    return v == r

flag = b"ictf{REDACTED}"
p, q, g, x, y, k = gen_key()
ms = b"jctf{powered_by_caffeine}", b"jctf{totally_real_flag}"
sigs = [sign(m) for m in ms]
assert all(verify(m, *sig) for m, sig in zip(ms, sigs))
aes = AES.new(long_to_bytes(x)[:16], AES.MODE_CBC, b'\0'*16)
c = aes.encrypt(pad(flag, 16)).hex()
print(f'{p = }\n{g = }\n{y = }\n{ms = }\n{sigs = }\n{c = }')
```
And the output of the above script:
```text
p = 2187927460624367866053138955407692648682473743053236246707558906253741042480006602164664427
g = 375559713231366027661456501312210678588547344177468345614581736759352578046940519482449005
y = 1485107098193369513854775432342726913250546508148678604594096036026212707003506931382110518
ms = (b'jctf{powered_by_caffeine}', b'jctf{totally_real_flag}')
sigs = [(584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 43566017108108194938809536454030127793515021629016835721136006757000695802735201944729583), (584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 587754055422977160798386807229397695762555861352726788417293238718373985110611538922057038)]
c = '614585db552484e4c81c4168afa8582bd975bfadd5edc8a4d1bf744c29a7d84f30cde5fe4b37f736af3f09480bcb626a'
```
Source code analysis
The script implements the Digital Signature Algorithm (DSA). It then uses the private key x as the AES key to encrypt the flag. More precisely, the key is the first 16 bytes of x's big-endian encoding:

```python
aes = AES.new(long_to_bytes(x)[:16], AES.MODE_CBC, b'\0'*16)
c = aes.encrypt(pad(flag, 16)).hex()
```

Therefore, we need to recover x from the DSA signatures.
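The following sketch makes the key derivation concrete. It reproduces long_to_bytes(x)[:16] using only the standard library. The helper name aes_key_from_x is ours and is not part of the challenge or PyCryptodome. It assumes x > 0, which holds here because x comes from randrange(2, q). The sketch shows that the AES key consists of the 16 most significant bytes of x.

```python
def aes_key_from_x(x: int) -> bytes:
    """Hypothetical stdlib equivalent of long_to_bytes(x)[:16], assuming x > 0."""
    # Minimal big-endian encoding of x, with no leading zero bytes
    raw = x.to_bytes((x.bit_length() + 7) // 8, 'big')
    # Keep only the 16 most significant bytes, as the challenge does
    return raw[:16]


# A 38-byte integer (roughly the size of a 300-bit x): only its first 16 bytes form the key
x = int.from_bytes(bytes(range(1, 39)), 'big')
print(aes_key_from_x(x).hex())
```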
DSA implementation
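The code for the DSA is correct (more information here):
- It generates prime numbers $p$ and $q$ (such that $p = 2q + 1$)
- Then public values $g$ and $y = g^x \bmod p$
- And private values $x$ and $k$

In order to sign a message $m$, it uses SHA256 as the hash function $H$ and computes:

$$r = \left(g^k \bmod p\right) \bmod q$$
$$s = k^{-1}\left(H(m) + x \cdot r\right) \bmod q$$

And the output is the pair $(r, s)$.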
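To sanity-check the scheme, here is a self-contained toy version of the challenge's sign and verify. It is a sketch under our own assumptions. Parameters are passed explicitly instead of being read from globals. The safe prime is only 64 bits and is generated with a hand-written Miller-Rabin test so that the sketch runs quickly. The helpers is_probable_prime and gen_params are hypothetical and stand in for PyCryptodome's isPrime and getPrime.

```python
import random
from hashlib import sha256


def is_probable_prime(n, rounds=20):
    """Miller-Rabin probabilistic primality test (toy helper, not PyCryptodome)."""
    if n < 2:
        return False
    for small in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29):
        if n % small == 0:
            return n == small
    # Write n - 1 as d * 2^s with d odd
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for _ in range(rounds):
        a = random.randrange(2, n - 1)
        v = pow(a, d, n)
        if v in (1, n - 1):
            continue
        for _ in range(s - 1):
            v = pow(v, 2, n)
            if v == n - 1:
                break
        else:
            return False  # a is a witness that n is composite
    return True


def gen_params(bits=64):
    """Return a safe prime p = 2q + 1 and an element g of order q."""
    while True:
        # Force the top bit (size) and the low bit (odd)
        q = random.getrandbits(bits) | (1 << (bits - 1)) | 1
        p = 2 * q + 1
        if is_probable_prime(q) and is_probable_prime(p):
            break
    # Squaring lands in the quadratic residues, a subgroup of prime order q;
    # the base avoids 1 and p - 1, so g != 1 and g generates that subgroup
    g = pow(random.randrange(2, p - 1), 2, p)
    return p, q, g


def H(msg):
    return int.from_bytes(sha256(msg).digest(), 'big')


def sign(m, p, q, g, x, k):
    r = pow(g, k, p) % q
    s = (H(m) + x * r) * pow(k, -1, q) % q
    return r, s


def verify(m, r, s, p, q, g, y):
    if not (0 < r < q and 0 < s < q):
        return False
    u = pow(s, -1, q)
    v = pow(g, H(m) * u, p) * pow(y, r * u, p) % p % q
    return v == r


random.seed(1)  # deterministic toy run
p, q, g = gen_params()
x = random.randrange(2, q)  # private key
y = pow(g, x, p)            # public key
k = random.randrange(2, q)  # nonce
sig = sign(b"hello", p, q, g, x, k)
print(verify(b"hello", *sig, p, q, g, y), verify(b"other", *sig, p, q, g, y))
```

Verification works because $u = s^{-1}$ satisfies $u\,(H(m) + x r) \equiv k \pmod q$ and $g$ has order $q$. As a result, $g^{H(m) u} y^{r u} \equiv g^k \pmod p$.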
The security flaw
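The problem here is that two different messages are signed with the same nonce value $k$. The nonce is drawn once in gen_key and reused by every call to sign. The output confirms this, because both signatures share the same $r$.

So, we have messages $m_1$ and $m_2$ with signatures $(r, s_1)$ and $(r, s_2)$. Therefore, we have:

$$s_1 = k^{-1}\left(H(m_1) + x \cdot r\right) \bmod q$$
$$s_2 = k^{-1}\left(H(m_2) + x \cdot r\right) \bmod q$$

If we subtract both equations, the $x \cdot r$ terms cancel and we get:

$$s_1 - s_2 = k^{-1}\left(H(m_1) - H(m_2)\right) \bmod q$$

So we can find $k^{-1}$, and hence $k$:

$$k^{-1} = (s_1 - s_2)\left(H(m_1) - H(m_2)\right)^{-1} \bmod q$$

Then, solving the first signature equation for $x$:

$$x = \left(s_1 \cdot k - H(m_1)\right) r^{-1} \bmod q$$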
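The next sketch checks the nonce-reuse recovery independently of the challenge's secrets. It reuses the challenge's public p and g from the output. It then picks its own arbitrary x and k, signs the two challenge messages with the reused nonce, and recovers both secrets by subtracting the two signature equations. The values of x and k and the helper name recover_from_reused_nonce are ours, chosen for illustration.

```python
from hashlib import sha256


def H(msg):
    return int.from_bytes(sha256(msg).digest(), 'big')


# Public group parameters copied from the challenge output; q follows from p = 2q + 1
p = 2187927460624367866053138955407692648682473743053236246707558906253741042480006602164664427
g = 375559713231366027661456501312210678588547344177468345614581736759352578046940519482449005
q = (p - 1) // 2

# Our own arbitrary secrets (both smaller than q), so the recovery can be checked
x = 0x1337C0FFEE1337C0FFEE1337C0FFEE1337C0FFEE
k = 0xDEADBEEFCAFEBABEDEADBEEFCAFEBABE
y = pow(g, x, p)


def sign(m):
    # Same formulas as the challenge; k is reused on every call (the bug)
    r = pow(g, k, p) % q
    s = (H(m) + x * r) * pow(k, -1, q) % q
    return r, s


def recover_from_reused_nonce(m1, sig1, m2, sig2):
    """Return (k, x) from two signatures produced with the same nonce."""
    (r1, s1), (r2, s2) = sig1, sig2
    if r1 != r2:
        raise ValueError("different r values: the nonce was not reused")
    h1, h2 = H(m1), H(m2)
    # s1 - s2 = k^-1 * (h1 - h2)  (mod q)
    k_inv = (s1 - s2) * pow(h1 - h2, -1, q) % q
    k_rec = pow(k_inv, -1, q)
    # s1 = k^-1 * (h1 + x * r)  =>  x = (s1 * k - h1) * r^-1  (mod q)
    x_rec = (s1 * k_rec - h1) * pow(r1, -1, q) % q
    return k_rec, x_rec


m1, m2 = b"jctf{powered_by_caffeine}", b"jctf{totally_real_flag}"
sig1, sig2 = sign(m1), sign(m2)
k_rec, x_rec = recover_from_reused_nonce(m1, sig1, m2, sig2)
print(k_rec == k, x_rec == x, pow(g, x_rec, p) == y)
```

The final pow(g, x_rec, p) == y comparison is a cheap way to confirm a recovered key against the public key before using it. The same check could be applied to the real challenge values.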
Flag
If we reproduce the above calculations in Python, we find x and can therefore decrypt the AES ciphertext to get the flag. Note that q is not printed, but since $p = 2q + 1$ we can recover it as $q = (p - 1)/2$:
```
$ python3 -q
>>> from Crypto.Cipher import AES
>>> from Crypto.Util.number import long_to_bytes
>>> from Crypto.Util.Padding import unpad
>>> from hashlib import sha256
>>>
>>> def H(msg):
...     return int.from_bytes(sha256(msg).digest(), 'big')
...
>>>
>>> p = 2187927460624367866053138955407692648682473743053236246707558906253741042480006602164664427
>>> g = 375559713231366027661456501312210678588547344177468345614581736759352578046940519482449005
>>> y = 1485107098193369513854775432342726913250546508148678604594096036026212707003506931382110518
>>> ms = (b'jctf{powered_by_caffeine}', b'jctf{totally_real_flag}')
>>> sigs = [(584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 43566017108108194938809536454030127793515021629016835721136006757000695802735201944729583), (584760320483109456978677291524162809623560744424005784846002481292183647634857441612413242, 587754055422977160798386807229397695762555861352726788417293238718373985110611538922057038)]
>>> c = '614585db552484e4c81c4168afa8582bd975bfadd5edc8a4d1bf744c29a7d84f30cde5fe4b37f736af3f09480bcb626a'
>>>
>>> q = (p - 1) // 2
>>> r = sigs[0][0]
>>> s1, s2 = sigs[0][1], sigs[1][1]
>>> H1, H2 = H(ms[0]), H(ms[1])
>>>
>>> k_inv = (s1 - s2) * pow(H1 - H2, -1, q) % q
>>> k = pow(k_inv, -1, q)
>>> x = (s1 * k - H1) * pow(r, -1, q) % q
>>>
>>> aes = AES.new(long_to_bytes(x)[:16], AES.MODE_CBC, b'\0'*16)
>>> unpad(aes.decrypt(bytes.fromhex(c)), 16)
b'ictf{4_n0nc3_5h0u!d_0n!y_b3_u53d_0NC3!!!}'
```