Living with Elegance
We are given the following Python script that encrypts the flag:
from secrets import token_bytes, randbelow
from Crypto.Util.number import bytes_to_long as b2l


class ElegantCryptosystem:
    def __init__(self):
        self.d = 16
        self.n = 256
        self.S = token_bytes(self.d)

    def noise_prod(self):
        return randbelow(2*self.n//3) - self.n//2

    def get_encryption(self, bit):
        A = token_bytes(self.d)
        b = self.punc_prod(A, self.S) % self.n
        e = self.noise_prod()
        if bit == 1:
            return A, b + e
        else:
            return A, randbelow(self.n)

    def punc_prod(self, x, y):
        return sum(_x * _y for _x, _y in zip(x, y))


def main():
    FLAGBIN = bin(b2l(open('flag.txt', 'rb').read()))[2:]
    crypto = ElegantCryptosystem()

    while True:
        idx = input('Specify the index of the bit you want to get an encryption for : ')
        if not idx.isnumeric():
            print('The index must be an integer.')
            continue
        idx = int(idx)
        if idx < 0 or idx >= len(FLAGBIN):
            print(f'The index must lie in the interval [0, {len(FLAGBIN)-1}]')
            continue
        bit = int(FLAGBIN[idx])
        A, b = crypto.get_encryption(bit)
        print('Here is your ciphertext: ')
        print(f'A = {b2l(A)}')
        print(f'b = {b}')


if __name__ == '__main__':
    main()
Source code analysis
The server lets us request the encryption of any bit of the flag, selected by its index:
while True:
    idx = input('Specify the index of the bit you want to get an encryption for : ')
    if not idx.isnumeric():
        print('The index must be an integer.')
        continue
    idx = int(idx)
    if idx < 0 or idx >= len(FLAGBIN):
        print(f'The index must lie in the interval [0, {len(FLAGBIN)-1}]')
        continue
    bit = int(FLAGBIN[idx])
    A, b = crypto.get_encryption(bit)
    print('Here is your ciphertext: ')
    print(f'A = {b2l(A)}')
    print(f'b = {b}')
The ciphertext is composed of two values: `A` and `b`. Let's see how they are computed:
class ElegantCryptosystem:
    def __init__(self):
        self.d = 16
        self.n = 256
        self.S = token_bytes(self.d)

    def noise_prod(self):
        return randbelow(2*self.n//3) - self.n//2

    def get_encryption(self, bit):
        A = token_bytes(self.d)
        b = self.punc_prod(A, self.S) % self.n
        e = self.noise_prod()
        if bit == 1:
            return A, b + e
        else:
            return A, randbelow(self.n)

    def punc_prod(self, x, y):
        return sum(_x * _y for _x, _y in zip(x, y))
- When the class `ElegantCryptosystem` is initialized, a secret value `S` is generated as 16 random bytes
- Then, on each encryption, the server generates a fresh random value `A` (also 16 bytes)
- Then it multiplies `A` and `S` using a scalar product (`punc_prod`) and reduces the result modulo `n = 256` to get `b`
- `e` is some random noise, produced by `noise_prod`
- If the bit is `1`, the result is `A` and `b + e`; otherwise it is `A` and a random number below `n = 256`
Learning With Errors
In math terms, we can express `A` and `S` as vectors of 8-bit integers:
$$ A = (a_0, a_1, \dots, a_{15}) \qquad S = (s_0, s_1, \dots, s_{15}) $$
Then, `b` is their scalar product, reduced modulo $n = 256$:
$$ b = (A \circ S) \bmod 256 = \left( \sum_{i = 0}^{15} a_i \cdot s_i \right) \bmod 256 $$
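As a toy illustration with only two coordinates instead of sixteen (the values are arbitrary, not taken from the challenge), the scalar product and the reduction work like this:
$$ A = (200, 100), \quad S = (3, 5) \implies A \circ S = 200 \cdot 3 + 100 \cdot 5 = 1100, \quad b = 1100 \bmod 256 = 1100 - 4 \cdot 256 = 76 $$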
And if the bit to encrypt is a `1`, then we receive the following ciphertext:
$$ \begin{cases} A = (a_0, a_1, \dots, a_{15}) \\ c = b + e = \left( (A \circ S) \bmod 256 \right) + e \end{cases} $$
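This construction is based on Learning With Errors (LWE), and the initials of the challenge name (Living With Elegance) are a hint. The idea behind LWE is that it is hard to recover the secret key $S$ from pairs $(A, c)$, because $c$ is just $b$ corrupted with some random noise $e$. Notice, however, that standard LWE reduces the noisy value modulo $n$, whereas here $e$ is added after the reduction, so $c$ is not reduced again.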
Solution
This time, we need to find a way to tell whether we have been given `b + e` or `randbelow(self.n)`. We can do this by looking at the range of values each of them can take:
- We know that `b` must be an integer in $[0, 256)$ because it is reduced modulo `n = 256`
- The value of `e` comes from `noise_prod`, which is computed as `randbelow(2*self.n//3) - self.n//2`. Since `2*256//3 = 170`, we have the following range for `e`:
$$ [0, \lfloor 2 \cdot 256 / 3 \rfloor) - 256 / 2 = [0, 170) - 128 = [-128, 42) $$
- Therefore, the range of `b + e` is (its largest value is $255 + 41 = 296$):
$$ [0, 256) + [-128, 42) = [-128, 297) $$
- On the other hand, `randbelow(self.n)` is an integer in $[0, 256)$
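Interval arithmetic like this is easy to get off by one, so a brute-force check can confirm the bounds. This sketch simply enumerates every possible `b` and every output of `randbelow(170)`; it does not talk to the server, and `sumBounds` is a hypothetical helper name.

```go
package main

import "fmt"

const n = 256

// sumBounds enumerates every b in [0, n) and every e = randbelow(2*n//3) - n//2,
// returning the smallest and largest reachable value of c = b + e.
func sumBounds() (lo, hi int) {
	lo, hi = n, -n // start outside any reachable value
	for b := 0; b < n; b++ {
		for r := 0; r < 2*n/3; r++ { // r plays the role of randbelow(170)
			c := b + r - n/2
			if c < lo {
				lo = c
			}
			if c > hi {
				hi = c
			}
		}
	}
	return lo, hi
}

func main() {
	lo, hi := sumBounds()
	fmt.Printf("b + e lies in [%d, %d], i.e. [%d, %d)\n", lo, hi, lo, hi+1)
}
```

Running it reports the closed interval $[-128, 296]$, which is the half-open interval $[-128, 297)$ stated above.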
As a result, we have a way to detect a `1` bit: if we receive a value $c$ such that $c < 0$ or $c \geqslant 256$, then we can be sure that it came from `b + e`, so the encrypted bit is a `1`. The converse does not hold: a value inside $[0, 256)$ can come from either branch.
Since we have unlimited queries to the server, we can take a probabilistic approach. For instance, we query the same bit 30 times:
- If any of these queries returns $c < 0$ or $c \geqslant 256$, then we know for sure that the bit is a `1`, so we stop and move on to the next bit
- If none of the queries gives a result outside of $[0, 256)$, then we assume, with high probability, that the encrypted bit is a `0`

We can repeat this process for all the bits until we have the flag.
Implementation
This time, I am using Go with my gopwntools module. This is a helper function to query the server on a given bit index (it only returns the ciphertext value `b`, because `A` is not relevant):
func getEncryption(index int) int {
	io.SendLineAfter([]byte("Specify the index of the bit you want to get an encryption for : "), []byte(strconv.Itoa(index)))
	// Skip the "A = ..." line; only b is needed for the attack
	io.RecvUntil([]byte("b = "))
	c, _ := strconv.Atoi(strings.TrimSpace(io.RecvLineS()))
	return c
}
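`getEncryption` discards the error returned by `strconv.Atoi`, so a malformed line would silently become `c = 0`, which the solver treats as "no evidence of a 1". A stricter parser could look like the sketch below; `parseCiphertext` is a hypothetical helper that works on plain strings and is not part of gopwntools.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseCiphertext extracts c from a server line such as "b = -17".
// Unlike Atoi with the error ignored, it reports malformed input
// instead of silently returning 0.
func parseCiphertext(line string) (int, error) {
	rest, ok := strings.CutPrefix(strings.TrimSpace(line), "b = ")
	if !ok {
		return 0, fmt.Errorf("unexpected line %q", line)
	}
	return strconv.Atoi(rest)
}

func main() {
	for _, line := range []string{"b = 296\n", "b = -17", "garbage"} {
		c, err := parseCiphertext(line)
		fmt.Println(c, err)
	}
}
```

`strings.CutPrefix` requires Go 1.20 or later, which is already satisfied by the `for range 30` loop in the solver (Go 1.22).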
We can find the bit length of the flag by sending an out-of-range index (the server replies with an error message that shows the maximum valid index), and then recover the bits with the above procedure:
func main() {
	io = getProcess()
	defer io.Close()

	// An out-of-range index makes the server reveal the maximum valid index
	io.SendLineAfter([]byte("Specify the index of the bit you want to get an encryption for : "), []byte("10000"))
	io.RecvUntil([]byte("The index must lie in the interval [0, "))
	bitLength, _ := strconv.Atoi(io.RecvUntilS([]byte{']'}, true))
	bitLength++ // maximum index + 1

	bits := make([]int, bitLength) // every bit starts as 0
	prog := pwn.Progress("Bits")

	for i := 0; i < bitLength; i++ {
		prog.Status(fmt.Sprintf("%d / %d", i+1, bitLength))
		for range 30 {
			c := getEncryption(i)
			// Only b + e can fall outside [0, 256)
			if c < 0 || c >= 256 {
				bits[i] = 1
				break
			}
		}
	}

	prog.Success(fmt.Sprintf("%[1]d / %[1]d", bitLength))
Next, we pad the bits with leading zeros up to a multiple of 8 (for later decoding). This is needed because `bin()` drops the leading zero bits of the flag's first byte:
	for len(bits)%8 != 0 {
		bits = append([]int{0}, bits...)
	}
Finally, we decode each 8-bit block to a byte and show the flag as a string:
	flag := make([]byte, len(bits)/8)
	// Iterate over the padded length, one 8-bit block at a time
	for i := 0; i < len(bits); i += 8 {
		for j, v := range bits[i : i+8] {
			flag[i/8] |= byte(v << (7 - j)) // most significant bit first
		}
	}

	pwn.Success(string(flag))
}
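The padding and decoding steps can also be written as one standalone function, which is handy for testing the decoding offline. This is a sketch: `bitsToBytes` is a hypothetical name, and the example input is `bin(b2l(b'HTB'))[2:]` worked out by hand ('H' is `0x48`, whose leading zero bit `bin()` drops), not output from the server.

```go
package main

import "fmt"

// bitsToBytes left-pads bits with zeros to a multiple of 8 and packs each
// 8-bit block into a byte, most significant bit first.
func bitsToBytes(bits []int) []byte {
	pad := (8 - len(bits)%8) % 8
	padded := append(make([]int, pad), bits...)
	out := make([]byte, len(padded)/8)
	for i, v := range padded {
		out[i/8] |= byte(v) << (7 - i%8)
	}
	return out
}

func main() {
	// 23 bits: "HTB" without the leading 0 of 'H'
	s := "10010000101010001000010"
	bits := make([]int, len(s))
	for i, ch := range s {
		bits[i] = int(ch - '0')
	}
	fmt.Println(string(bitsToBytes(bits))) // HTB
}
```

Computing the padding up front with `(8 - len(bits)%8) % 8` avoids the repeated prepends of the loop-based version, although for a flag of a few hundred bits either approach is fast enough.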
Flag
With this script, we get the flag:
$ echo 'HTB{f4k3_fl4g_f0r_t3st1ng}' > flag.txt
$ go run solve.go
[+] Starting local process 'python3': pid 91076
[+] Bits: 215 / 215
[+] HTB{f4k3_fl4g_f0r_t3st1ng}
[*] Stopped process 'python3' (pid 91076)
$ go run solve.go 94.237.52.200:59555
[+] Opening connection to 94.237.52.200 on port 59555: Done
[+] Bits: 175 / 175
[+] HTB{s3cur3_cust0m_LW3}
[*] Closed connection to 94.237.52.200 port 59555
The full script can be found here: solve.go.