AESWCM
We got the Python source code of the server:
```python
from Crypto.Util.Padding import pad
from Crypto.Cipher import AES
import os
import random
from secret import FLAG

KEY = os.urandom(16)
IV = os.urandom(16)


class AESWCM:

    def __init__(self, key):
        self.key = key
        self.cipher = AES.new(self.key, AES.MODE_ECB)
        self.BLOCK_SIZE = 16

    def pad(self, pt):
        if len(pt) % self.BLOCK_SIZE != 0:
            pt = pad(pt, self.BLOCK_SIZE)
        return pt

    def blockify(self, message):
        return [
            message[i:i + self.BLOCK_SIZE]
            for i in range(0, len(message), self.BLOCK_SIZE)
        ]

    def xor(self, a, b):
        return bytes([aa ^ bb for aa, bb in zip(a, b)])

    def encrypt(self, pt, iv):
        pt = self.pad(pt)
        blocks = self.blockify(pt)
        xor_block = iv
        ct = []
        for block in blocks:
            ct_block = self.cipher.encrypt(self.xor(block, xor_block))
            xor_block = self.xor(block, ct_block)
            ct.append(ct_block)
        return b"".join(ct).hex()

    def decrypt(self, ct, iv):
        ct = bytes.fromhex(ct)
        blocks = self.blockify(ct)
        xor_block = iv
        pt = []
        for block in blocks:
            pt_block = self.xor(self.cipher.decrypt(block), xor_block)
            xor_block = self.xor(block, pt_block)
            pt.append(pt_block)
        return b"".join(pt)

    def tag(self, pt, iv=os.urandom(16)):
        blocks = self.blockify(bytes.fromhex(self.encrypt(pt, iv)))
        random.shuffle(blocks)
        ct = blocks[0]
        for i in range(1, len(blocks)):
            ct = self.xor(blocks[i], ct)
        return ct.hex()


def main():
    aes = AESWCM(KEY)
    tags = []
    characteristics = []
    print("What properties should your magic wand have?")
    message = "Property: "
    counter = 0
    while counter < 3:
        characteristic = bytes.fromhex(input(message))
        if characteristic not in characteristics:
            characteristics.append(characteristic)
            characteristic_tag = aes.tag(message.encode() + characteristic, IV)
            tags.append(characteristic_tag)
            print(characteristic_tag)
            if len(tags) > len(set(tags)):
                print(FLAG)
            counter += 1
        else:
            print("Only different properties are allowed!")
            exit(1)


if __name__ == "__main__":
    main()
```
Source code analysis
First of all, the program initializes the `KEY` and `IV` variables to 16 random bytes each. After that, an instance of the `AESWCM` class is created.
Then, we are asked to enter a message (in hexadecimal) that will be tagged. Our message (`characteristic`) is appended to the list of messages (`characteristics`), and its tag is appended to the list of tags (`tags`). We can do this at most three times, and every message must be different from the previous ones; otherwise, the program exits.
We will see the flag when `len(tags) > len(set(tags))`, that is, when `tags` contains a repeated element (a `set` keeps no duplicates). Therefore, the goal is to find two different messages that produce the same tag.
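As a quick illustration of the win condition, a list is longer than the set built from it exactly when some element repeats. `has_duplicate` is a hypothetical helper for this sketch, not part of the challenge:

```python
def has_duplicate(tags):
    # A set drops repeated elements, so it is shorter than the list
    # exactly when the list contains at least one repeated tag
    return len(tags) > len(set(tags))


print(has_duplicate(['aa', 'bb', 'cc']))  # False
print(has_duplicate(['aa', 'bb', 'aa']))  # True
```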
The code below reflects the above explanation:
```python
message = "Property: "
counter = 0
while counter < 3:
    characteristic = bytes.fromhex(input(message))
    if characteristic not in characteristics:
        characteristics.append(characteristic)
        characteristic_tag = aes.tag(message.encode() + characteristic, IV)
        tags.append(characteristic_tag)
        print(characteristic_tag)
        if len(tags) > len(set(tags)):
            print(FLAG)
        counter += 1
    else:
        print("Only different properties are allowed!")
        exit(1)
```
Analyzing the encryption algorithm
A message (`characteristic`) is tagged by calling the `tag` method on `message.encode() + characteristic`, that is, `"Property: "` followed by our input data.
This is the `tag` method:
```python
def tag(self, pt, iv=os.urandom(16)):
    blocks = self.blockify(bytes.fromhex(self.encrypt(pt, iv)))
    random.shuffle(blocks)
    ct = blocks[0]
    for i in range(1, len(blocks)):
        ct = self.xor(blocks[i], ct)
    return ct.hex()
```
It encrypts the plaintext with the `encrypt` method, splits the result into 16-byte blocks with `blockify`, and finally shuffles the blocks and XORs all of them together.
The shuffle is useless: XOR is commutative and associative, so the order of the blocks does not change the result.
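A small sketch supports this claim: XOR-folding the same blocks in many random orders always gives the same value. The helpers `xor` and `fold_xor` are illustrative names, and the blocks are random bytes rather than real AES output:

```python
import os
import random
from functools import reduce


def xor(a, b):
    # Byte-wise XOR, equivalent to AESWCM.xor
    return bytes(x ^ y for x, y in zip(a, b))


def fold_xor(blocks):
    # XOR all blocks together, like the loop in AESWCM.tag
    return reduce(xor, blocks)


blocks = [os.urandom(16) for _ in range(5)]
reference = fold_xor(blocks)

for _ in range(100):
    shuffled = blocks[:]
    random.shuffle(shuffled)
    # Commutativity and associativity: every ordering gives the same tag
    assert fold_xor(shuffled) == reference

print('tag is independent of the order:', reference.hex())
```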
Use of AES and XOR
This is the `encrypt` method:
```python
def encrypt(self, pt, iv):
    pt = self.pad(pt)
    blocks = self.blockify(pt)
    xor_block = iv
    ct = []
    for block in blocks:
        ct_block = self.cipher.encrypt(self.xor(block, xor_block))
        xor_block = self.xor(block, ct_block)
        ct.append(ct_block)
    return b"".join(ct).hex()
```
First, it pads the plaintext so that its length is a multiple of 16 (the block size). After that, it splits the plaintext into 16-byte blocks and encrypts them one by one, chaining each block into the next.
Let’s recall that the cipher is AES ECB:
```python
class AESWCM:
    def __init__(self, key):
        self.key = key
        self.cipher = AES.new(self.key, AES.MODE_ECB)
        self.BLOCK_SIZE = 16
```
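However, the `encrypt` function is similar to AES in CBC mode, but not exactly the same. In CBC mode, each plaintext block is XORed with the previous ciphertext block before being encrypted, that is, $c_i = \mathrm{AES}(p_i \oplus c_{i-1})$ with $c_0 = \mathrm{IV}$.
This time, the feedback from each block to the next is not the ciphertext block, but the XOR of the ciphertext block with the plaintext block (this construction is known as PCBC, propagating cipher block chaining). Say we have 3 plaintext blocks ($p_1$, $p_2$, $p_3$):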
$$ c_1 = \mathrm{AES}\big(p_1 \oplus \mathrm{IV}\big) $$ $$ c_2 = \mathrm{AES}\big(p_2 \oplus (p_1 \oplus c_1)\big) $$ $$ c_3 = \mathrm{AES}\big(p_3 \oplus (p_2 \oplus c_2)\big) $$
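The equations above can be checked with a short sketch. To keep it runnable with the standard library only, it assumes a stand-in block function `E` (HMAC-SHA256 truncated to 16 bytes) instead of real AES; only the chaining logic, which mirrors `AESWCM.encrypt`, matters here. `KEY`, `E`, `encrypt_aeswcm` and `encrypt_cbc` are hypothetical names:

```python
import hashlib
import hmac
import os

BLOCK_SIZE = 16
KEY = os.urandom(16)  # hypothetical secret key for the demo


def E(block):
    # Stand-in for one AES-ECB block encryption: HMAC-SHA256 truncated
    # to 16 bytes. It is NOT AES, but it is deterministic and keyed.
    return hmac.new(KEY, block, hashlib.sha256).digest()[:BLOCK_SIZE]


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def encrypt_aeswcm(blocks, iv):
    # Feedback is p_i XOR c_i, as in AESWCM.encrypt (PCBC chaining)
    out, feedback = [], iv
    for p in blocks:
        c = E(xor(p, feedback))
        feedback = xor(p, c)
        out.append(c)
    return out


def encrypt_cbc(blocks, iv):
    # Classic CBC for comparison: feedback is just c_i
    out, feedback = [], iv
    for p in blocks:
        c = E(xor(p, feedback))
        feedback = c
        out.append(c)
    return out


iv = os.urandom(16)
p1, p2, p3 = b'A' * 16, b'B' * 16, b'C' * 16
c1, c2, c3 = encrypt_aeswcm([p1, p2, p3], iv)

# The three equations from the text
assert c1 == E(xor(p1, iv))
assert c2 == E(xor(p2, xor(p1, c1)))
assert c3 == E(xor(p3, xor(p2, c2)))

# Both modes agree on the first block; they differ only in the feedback
assert encrypt_cbc([p1], iv) == [c1]
```

The attack only uses encryption, so a keyed one-way function is enough as a stand-in; it would not support `decrypt`.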
Padding implementation
Padding is needed for AES because every block must be exactly 16 bytes long, so the length of the plaintext must be a multiple of 16. This is the `self.pad` method:
```python
def pad(self, pt):
    if len(pt) % self.BLOCK_SIZE != 0:
        pt = pad(pt, self.BLOCK_SIZE)
    return pt
```
If the length of the plaintext is not divisible by 16, it calls `pad` (from `Crypto.Util.Padding`, which applies PKCS#7 padding by default); otherwise, the plaintext is returned unchanged.
Here is the flaw. PKCS#7 padding must be applied to every plaintext, whether or not its length is already a multiple of 16; otherwise, a message that ends in bytes that look like padding cannot be told apart from a shorter, padded one. This is how `pad` behaves:
```
$ python3 -q
>>> from Crypto.Util.Padding import pad
>>> pad(b'A' * 16, 16)
b'AAAAAAAAAAAAAAAA\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10'
>>> pad(b'A' * 32, 16)
b'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10'
```
PKCS#7 padding takes the number $n$ of bytes needed to reach the next multiple of 16 and appends $n$ bytes with value $n$ (for instance, if we need 3 more bytes, the padding is `"\x03\x03\x03"`). In the above output, the plaintext length is already divisible by 16, and still `pad` appends a full block of 16 bytes `"\x10"` as padding.
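The padding rule and the flaw can be reproduced without pycryptodome. This is a minimal re-implementation assuming the same PKCS#7 behaviour shown in the REPL output above; `pkcs7_pad` and `server_pad` are hypothetical names:

```python
BLOCK_SIZE = 16


def pkcs7_pad(data, block_size=BLOCK_SIZE):
    # n is always in 1..block_size: a full block is appended when the
    # length is already a multiple of block_size
    n = block_size - len(data) % block_size
    return data + bytes([n]) * n


def server_pad(data, block_size=BLOCK_SIZE):
    # Mirrors AESWCM.pad: aligned inputs are left unpadded
    if len(data) % block_size != 0:
        data = pkcs7_pad(data, block_size)
    return data


assert pkcs7_pad(b'A' * 13) == b'A' * 13 + b'\x03\x03\x03'
assert pkcs7_pad(b'A' * 16) == b'A' * 16 + b'\x10' * 16

# Correct PKCS#7 keeps these two inputs apart...
assert pkcs7_pad(b'Property: ') != pkcs7_pad(b'Property: ' + b'\x06' * 6)
# ...but the conditional padding maps both to the same 16 bytes
assert server_pad(b'Property: ') == server_pad(b'Property: ' + b'\x06' * 6)
```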
Abusing the wrong implementation
Recall that we get the flag when two different messages (`characteristic`) produce the same output (`tag`), so that the list of tags is longer than its set.
To do so, we can first enter an empty message (which will be padded) and then enter the padding bytes themselves as a second message. For instance:
```
>>> def wrong_pad(pt):
...     if len(pt) % 16 != 0:
...         return pad(pt, 16)
...     return pt
...
>>> wrong_pad(b'Property: ' + b'')
b'Property: \x06\x06\x06\x06\x06\x06'
>>> wrong_pad(b'Property: ' + b'\x06\x06\x06\x06\x06\x06')
b'Property: \x06\x06\x06\x06\x06\x06'
```
That's it: two different messages, the same padded plaintext, and hence the same ciphertext and the same `tag`.
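The whole unintended attack can be simulated locally. This sketch assumes a stand-in for AES (HMAC-SHA256 truncated to 16 bytes), since the collision happens before encryption and does not depend on the cipher; `E`, `server_pad` and `tag` are hypothetical re-implementations of the server logic, not the challenge code itself:

```python
import hashlib
import hmac
import os
import random

BLOCK_SIZE = 16
KEY = os.urandom(16)
IV = os.urandom(16)  # fixed for every call, like the global IV on the server


def E(block):
    # Stand-in for AES-ECB (HMAC-SHA256 truncated); NOT the real cipher
    return hmac.new(KEY, block, hashlib.sha256).digest()[:BLOCK_SIZE]


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def server_pad(pt):
    # Same flaw as AESWCM.pad: aligned inputs are not padded
    if len(pt) % BLOCK_SIZE != 0:
        n = BLOCK_SIZE - len(pt) % BLOCK_SIZE
        pt += bytes([n]) * n
    return pt


def tag(pt, iv):
    # AESWCM.encrypt followed by the shuffle-and-XOR of AESWCM.tag
    pt = server_pad(pt)
    blocks, feedback = [], iv
    for i in range(0, len(pt), BLOCK_SIZE):
        p = pt[i:i + BLOCK_SIZE]
        c = E(xor(p, feedback))
        feedback = xor(p, c)
        blocks.append(c)
    random.shuffle(blocks)
    result = blocks[0]
    for b in blocks[1:]:
        result = xor(b, result)
    return result.hex()


prefix = b'Property: '
first = bytes.fromhex('')               # empty input
second = bytes.fromhex('060606060606')  # the padding bytes themselves
tags = [tag(prefix + first, IV), tag(prefix + second, IV)]
print(first != second and len(tags) > len(set(tags)))  # True
```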
Flag
The messages must be entered in hexadecimal format (the first one is just an empty line):
```
$ nc 178.62.21.211 32535
What properties should your magic wand have?
Property:
aed8fd6e4d0b9210ddaccffbf63ed737
Property: 060606060606
aed8fd6e4d0b9210ddaccffbf63ed737
HTB{435_cu570m_m0d35_4nd_hm4c_423_fun!}
```
Intended solution
For the intended solution, we avoid padding altogether by sending messages whose total length (including the `"Property: "` prefix) is a multiple of 16 bytes. Let $p_{i,j}$ be the plaintext blocks we send to the server, $c_{i,j}$ the outputs of the AES encryption and $e_i$ the tag returned by the server (the XOR of all the $c_{i,j}$ of that round), where $i$ represents the round number and $j$ is the index of the block.
For the first round we will enter a single block, and we will have $e_1 = c_{1,1}$, where
- $c_{1,1} = \mathrm{AES}(p_{1,1} \oplus \mathrm{IV})$.
Then we will set:
- $p_{2,1} = p_{1,1}$
- $p_{2,2} = e_1 \oplus p_{1,1}$
So $e_2 = c_{2,1} \oplus c_{2,2}$, where
- $c_{2,1} = \mathrm{AES}(p_{2,1} \oplus \mathrm{IV})$
- $c_{2,2} = \mathrm{AES}\big(p_{2,2} \oplus (p_{2,1} \oplus c_{2,1})\big)$.
Notice that $c_{2,1} = c_{1,1} = e_1$ (the first block and the global `IV` are the same in both rounds) and
$$ c_{2,2} = \mathrm{AES}\big((e_1 \oplus p_{1,1}) \oplus (p_{1,1} \oplus e_1)\big) = \mathrm{AES}(0) $$
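Then we can simplify $e_2 = e_1 \oplus \mathrm{AES}(0)$. Since we know both $e_1$ and $e_2$, we learn $c_{2,2} = \mathrm{AES}(0) = e_1 \oplus e_2$ without knowing the key.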
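Rounds 1 and 2 can be checked offline with a simulated server. The sketch assumes a stand-in for AES (HMAC-SHA256 truncated to 16 bytes) and a hypothetical `server_tag` that reproduces `AESWCM.tag` for unpadded input (the shuffle is left out because it does not change the XOR):

```python
import hashlib
import hmac
import os

KEY, IV = os.urandom(16), os.urandom(16)


def E(block):
    # Stand-in for AES-ECB (NOT AES), keyed by a secret the attacker never sees
    return hmac.new(KEY, block, hashlib.sha256).digest()[:16]


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def server_tag(pt):
    # Unpadded path of AESWCM.tag: len(pt) must be a multiple of 16
    feedback, acc = IV, bytes(16)
    for i in range(0, len(pt), 16):
        p = pt[i:i + 16]
        c = E(xor(p, feedback))
        feedback = xor(p, c)
        acc = xor(acc, c)  # XOR of all ciphertext blocks
    return acc


# Round 1: a single block starting with the forced prefix
p_1_1 = b'Property: ' + bytes(6)
e_1 = server_tag(p_1_1)

# Round 2: choose p_2_2 so that the second block encrypts the zero block
p_2_1 = p_1_1
p_2_2 = xor(e_1, p_1_1)
e_2 = server_tag(p_2_1 + p_2_2)

# e_1 XOR e_2 leaks AES(0), here E(0), without the key
assert xor(e_1, e_2) == E(bytes(16))
```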
Finally, we will set
- $p_{3,1} = p_{2,1}$
- $p_{3,2} = p_{2,2}$
- $p_{3,3} = p_{2,2} \oplus c_{2,2}$
Then we have $e_3 = c_{3,1} \oplus c_{3,2} \oplus c_{3,3}$, where
- $c_{3,1} = \mathrm{AES}(p_{3,1} \oplus \mathrm{IV})$
- $c_{3,2} = \mathrm{AES}\big(p_{3,2} \oplus (p_{3,1} \oplus c_{3,1})\big)$
- $c_{3,3} = \mathrm{AES}\big(p_{3,3} \oplus (p_{3,2} \oplus c_{3,2})\big)$
Notice that
- $c_{3,1} = c_{2,1} = c_{1,1} = e_1$
- $c_{3,2} = c_{2,2} = \mathrm{AES}(0)$
On the other hand,
$$ c_{3,3} = \mathrm{AES}\big((p_{2,2} \oplus c_{2,2}) \oplus (p_{3,2} \oplus c_{3,2})\big) = \mathrm{AES}(0) $$
Then we will have $e_3 = c_{3,1} \oplus c_{3,2} \oplus c_{3,3} = e_1 \oplus \mathrm{AES}(0) \oplus \mathrm{AES}(0) = e_1$.
Since $e_3 = e_1$ and the three messages are different (they even have different lengths), the list of tags has a duplicated element and we get the flag.
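Putting the three rounds together, a local simulation should give $e_3 = e_1$. This check replaces AES with an assumed stand-in (HMAC-SHA256 truncated to 16 bytes), and `server_tag` is a hypothetical re-implementation of the tagging for unpadded input:

```python
import hashlib
import hmac
import os

KEY, IV = os.urandom(16), os.urandom(16)


def E(block):
    # Stand-in for AES-ECB (NOT AES)
    return hmac.new(KEY, block, hashlib.sha256).digest()[:16]


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def server_tag(pt):
    # Unpadded path of AESWCM.tag; the shuffle does not affect the XOR
    feedback, acc = IV, bytes(16)
    for i in range(0, len(pt), 16):
        p = pt[i:i + 16]
        c = E(xor(p, feedback))
        feedback = xor(p, c)
        acc = xor(acc, c)
    return acc


p_1_1 = b'Property: ' + bytes(6)
e_1 = server_tag(p_1_1)

p_2_1, p_2_2 = p_1_1, xor(e_1, p_1_1)
e_2 = server_tag(p_2_1 + p_2_2)
c_2_2 = xor(e_2, e_1)  # c_{2,1} = e_1, so this is AES(0)

p_3_1, p_3_2, p_3_3 = p_2_1, p_2_2, xor(p_2_2, c_2_2)
e_3 = server_tag(p_3_1 + p_3_2 + p_3_3)

# What the user actually types: everything after "Property: "
inputs = [m[10:] for m in (p_1_1, p_2_1 + p_2_2, p_3_1 + p_3_2 + p_3_3)]
tags = [e_1, e_2, e_3]
assert len(set(inputs)) == 3       # three different properties
assert len(tags) > len(set(tags))  # e_3 == e_1, so the flag would be printed
```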
Implementation in Python
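```python
import sys
from pwn import log, remote, xor


def get_process():
    # Assumed helper (the full solve.py has its own): connect to argv[1] = HOST:PORT
    host, port = sys.argv[1].split(':')
    return remote(host, int(port))


def main():
    io = get_process()
    # Round 1: a single block, "Property: " + 6 null bytes
    p_1_1 = b'Property: ' + bytes.fromhex('00' * 6)
    io.sendlineafter(b'Property: ', p_1_1[10:].hex().encode())
    e_1 = c_1_1 = bytes.fromhex(io.recvline().decode())
    # Round 2: the second block is encrypted as AES(0)
    p_2_1 = p_1_1
    p_2_2 = xor(e_1, p_1_1)
    io.sendlineafter(b'Property: ', (p_2_1 + p_2_2)[10:].hex().encode())
    e_2 = bytes.fromhex(io.recvline().decode())
    c_2_1 = c_1_1
    c_2_2 = xor(e_2, c_2_1)  # AES(0)
    # Round 3: the third block is also AES(0), so e_3 = e_1
    p_3_1 = p_2_1
    p_3_2 = p_2_2
    p_3_3 = xor(p_2_2, c_2_2)
    io.sendlineafter(b'Property: ', (p_3_1 + p_3_2 + p_3_3)[10:].hex().encode())
    io.recvline()  # the repeated tag
    log.success(f'Flag: {io.recvline()}')
    io.close()


if __name__ == '__main__':
    main()
```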
```
$ python3 solve.py 178.62.21.211:32535
[+] Opening connection to 178.62.21.211 on port 32535: Done
[+] Flag: b'HTB{435_cu570m_m0d35_4nd_hm4c_423_fun!}'
[*] Closed connection to 178.62.21.211 port 32535
```
The full script can be found here: solve.py.