cryptoGRAPHy (1, 2, 3)
This series of challenges combines graph theory with cryptography. The author implemented a Python library using `networkx` to handle graphs, together with cryptographic functions such as the AES cipher, HMAC and the SHA256 hash.
These challenges were somewhat controversial, because players needed to read, analyze and understand the library implementing the Graph Encryption Scheme before writing a solution for each challenge. Fortunately, the library didn't change between levels. Nevertheless, some teams complained because the solutions didn't involve much cryptography, but rather a lot of graph theory and programming skills.
Moreover, the first two challenges were mostly introductory, a way to get comfortable with the library. The last one was actually inspired by a realistic attack on a Graph Encryption Scheme, so it is way more interesting.
Anyway, I really enjoyed the challenges because they made me dust off my knowledge of graph theory and discrete math, so let's start!
Graph Encryption Scheme
In the library, we have three files (fortunately, most functions and classes have docstrings):
- `GES.py`: Implementation of the graph encryption scheme
- `DES.py`: Implementation of the dictionary encryption scheme
- `utils.py`: Utility functions
Understanding the encryption
Taking a look at `server.py` from cryptoGRAPHy 1, we can see that the server always starts by defining a graph (with `networkx`). Then it generates a key with `GESClass.keyGen` (which calls `DESClass.keyGen` under the hood):
```python
class GESClass:
    # ...
    def keyGen(self, security_parameter: int) -> bytes:
        '''
        Input: Security parameter
        Output: Secret key key_SKE||key_DES
        '''
        key_SKE = get_random_bytes(security_parameter)
        key_DES = DES.keyGen(security_parameter)
        return key_SKE + key_DES
    # ...

class DESClass:
    # ...
    def keyGen(self, security_parameter: int) -> bytes:
        '''
        Input: Security parameter
        Output: Secret key
        '''
        return get_random_bytes(security_parameter)
    # ...
```
Once the key is set, the graph is encrypted using `GESClass.encryptGraph`:
```python
class GESClass:
    # ...
    def encryptGraph(self, key: bytes, G: nx.Graph) -> dict[bytes, bytes]:
        '''
        Input: Secret key and a graph G
        Output: Encrypted graph encrypted_db
        '''
        SPDX = computeSPDX(key, G, self.cores)
        key_DES = key[16:]
        EDB = DES.encryptDict(key_DES, SPDX, self.cores)
        del(SPDX)
        gc.collect()
        return EDB
    # ...
```
As can be seen, the function calls `computeSPDX`:
```python
def computeSPDX(key: bytes, G: nx.Graph, cores: int) -> dict[bytes, bytes]:
    SPDX = {}
    chunk = round(len(G.nodes())/cores)
    key_SKE = key[:16]
    key_DES = key[16:]
    with Pool(cores) as pool:
        iterable = product([G], G)
        for S in pool.istarmap(computeSDSP, iterable, chunksize=chunk):
            for pair in S:
                label, value = pair[0], pair[1]
                label_bytes = utils.pair_to_bytes(label)
                value_bytes = utils.pair_to_bytes(value)
                if label_bytes not in SPDX:
                    token = DES.tokenGen(key_DES, value_bytes)
                    ct = utils.SymmetricEncrypt(key_SKE,value_bytes)
                    ct_value = token + ct
                    SPDX[label_bytes] = ct_value
    return SPDX
```
This function might be a bit difficult to understand (also, there is no docstring), because it calls yet another function named `computeSDSP`, which is executed inside `pool.istarmap` to speed up the process. In summary, `computeSDSP` is called once for every node of the graph, with arguments `G` and that node. This is `computeSDSP`:
```python
def computeSDSP(G: nx.Graph, root):
    '''
    Input: Graph G and a root
    Output: Tuples of the form ((start, root), (next_vertex, root))
    '''
    paths = nx.single_source_shortest_path(G, root)
    S = set()
    for _, path in paths.items():
        path.reverse()
        if len(path) > 1:
            for i in range(len(path)-1):
                label = (path[i], root)
                value = (path[i+1],root)
                S.add((label, value))
    return S
```
Again, this function might be hard to understand. The gist is that it takes `root` as the destination: `nx.single_source_shortest_path` computes a shortest path between `root` and every other reachable node of the graph (`paths`), and each path is reversed so that it ends at `root`. For each path, every pair of consecutive nodes is inserted into a set using the format `((node, root), (next, root))`. I interpret `((node, root), (next, root))` like:

> If I am in `node` and I want to reach `root`, I will move to `next`.
In the end, the set contains one entry for every other node reachable from `root`, telling the next hop towards `root`. Moreover, since it is stored as `(label, value)` (where `label = (node, root)` and `value = (next, root)`), we can easily rebuild the shortest path to `root` from any node of the graph by chaining entries. This structure is called a single-destination shortest path (SDSP).
Let's review the function `computeSPDX` again:
```python
def computeSPDX(key: bytes, G: nx.Graph, cores: int) -> dict[bytes, bytes]:
    SPDX = {}
    chunk = round(len(G.nodes())/cores)
    key_SKE = key[:16]
    key_DES = key[16:]
    with Pool(cores) as pool:
        iterable = product([G], G)
        for S in pool.istarmap(computeSDSP, iterable, chunksize=chunk):
            for pair in S:
                label, value = pair[0], pair[1]
                label_bytes = utils.pair_to_bytes(label)
                value_bytes = utils.pair_to_bytes(value)
                if label_bytes not in SPDX:
                    token = DES.tokenGen(key_DES, value_bytes)
                    ct = utils.SymmetricEncrypt(key_SKE,value_bytes)
                    ct_value = token + ct
                    SPDX[label_bytes] = ct_value
    return SPDX
```
Now, for each SDSP of the graph, the `label` and `value` variables are converted to bytes (a simple conversion):
```python
def int_to_bytes(x: int) -> bytes:
    return str(x).encode()

def pair_to_bytes(pair: Tuple[int, int]) -> bytes:
    return int_to_bytes(pair[0]) + b',' + int_to_bytes(pair[1])
```
Then, it uses the label (`label_bytes`) as the key of a dictionary (`SPDX`). The corresponding dictionary value is encrypted.
To encrypt the dictionary value, first a `token` is generated on `value_bytes` using `key_DES` (`key[16:]` from the first generated key):
```python
class DESClass:
    # ...
    def tokenGen(self, key: bytes, label: bytes) -> bytes:
        '''
        Input: A key and a label
        Output: A token on label
        '''
        K1 = utils.HashMAC(key, b'1'+label)[:16]
        K2 = utils.HashMAC(key, b'2'+label)[:16]
        return K1 + K2
    # ...
```
Where `utils.HashMAC` just implements HMAC with SHA256:
```python
def HashMAC(key: bytes, plaintext: bytes) -> bytes:
    '''
    Input: Key and plaintext
    Output: A token on plaintext with the key using HMAC
    '''
    token = HMAC.new(key, digestmod=SHA256)
    token.update(bytes(plaintext))
    return token.digest()
```
So, we can assume that `token` is uniquely related to `value_bytes` (which means that for every pair `(next, root)` there is a single associated `token`). Observe that `token` is 32 bytes long.
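The token construction can be reproduced with Python's standard `hmac` and `hashlib` modules. The library itself uses PyCryptodome's `HMAC` and `SHA256`, so this is an equivalent sketch, not the library code. It shows that a token is deterministic for a given key and label, and that different labels give different tokens (with overwhelming probability):

```python
import hashlib
import hmac
import os

def hash_mac(key: bytes, plaintext: bytes) -> bytes:
    # Same construction as utils.HashMAC: HMAC-SHA256, full 32-byte digest
    return hmac.new(key, plaintext, hashlib.sha256).digest()

def token_gen(key: bytes, label: bytes) -> bytes:
    # Same layout as DESClass.tokenGen: K1 || K2, 16 bytes each
    k1 = hash_mac(key, b'1' + label)[:16]
    k2 = hash_mac(key, b'2' + label)[:16]
    return k1 + k2

def pair_to_bytes(pair) -> bytes:
    # Same encoding as utils.pair_to_bytes
    return str(pair[0]).encode() + b',' + str(pair[1]).encode()

key_des = os.urandom(16)
t1 = token_gen(key_des, pair_to_bytes((3, 7)))
t2 = token_gen(key_des, pair_to_bytes((3, 7)))
t3 = token_gen(key_des, pair_to_bytes((4, 7)))
print(len(t1), t1 == t2, t1 == t3)  # 32 True False
```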
Continuing with the encryption, `value_bytes` is encrypted with AES using `key_SKE` (`key[:16]`) as the key. See `utils.SymmetricEncrypt` / `utils.SymmetricDecrypt`:
```python
def SymmetricEncrypt(key: bytes, plaintext: bytes) -> bytes:
    '''
    Encrypt the plaintext using AES-CBC mode with provided key.
    Input: 16-byte key and plaintext
    Output: Ciphertext
    '''
    if len(key) != 16:
        raise ValueError
    cipher = AES.new(key, AES.MODE_CBC)
    ct = cipher.encrypt(pad(plaintext, AES.block_size))
    iv = cipher.iv
    return ct + iv

def SymmetricDecrypt(key: bytes, ciphertext: bytes) -> bytes:
    '''
    Decrypt the ciphertex using AES-CBC mode with provided key.
    Input: 16-byte key and ciphertext
    Output: Plaintext
    '''
    if len(key) != 16:
        raise ValueError
    ct, iv = ciphertext[:-16], ciphertext[-16:]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    pt = unpad(cipher.decrypt(ct), AES.block_size)
    return pt
```
Notice that the IV of the AES-CBC cipher is appended to the end of the ciphertext after encryption, and taken from the end of the ciphertext for decryption.
Finally, the value that is added to `SPDX[label_bytes]` is `token + ct`.
OK, but we have not finished yet. Recall `GESClass.encryptGraph`; so far we have only computed `SPDX`:
```python
class GESClass:
    # ...
    def encryptGraph(self, key: bytes, G: nx.Graph) -> dict[bytes, bytes]:
        '''
        Input: Secret key and a graph G
        Output: Encrypted graph encrypted_db
        '''
        SPDX = computeSPDX(key, G, self.cores)
        key_DES = key[16:]
        EDB = DES.encryptDict(key_DES, SPDX, self.cores)
        del(SPDX)
        gc.collect()
        return EDB
    # ...
```
The next step is to take `key_DES` (`key[16:]`) and call `DESClass.encryptDict` to actually encrypt the whole `SPDX` dictionary, which is the value to return:
```python
class DESClass:
    # ...
    def encryptDict(self, key: bytes, plaintext_dx: dict[bytes, bytes], cores: int) -> dict[bytes, bytes]:
        '''
        Input: A key and a plaintext dictionary
        Output: An encrypted dictionary EDX
        '''
        encrypted_db = {}
        chunk = int(len(plaintext_dx)/cores)
        iterable = product([key], plaintext_dx.items())
        with Pool(cores) as pool:
            for ct_label, ct_value in pool.istarmap(encryptDictHelper, iterable, chunksize=chunk):
                encrypted_db[ct_label] = ct_value
        return encrypted_db
    # ...
```
Again, a bit weird to read, but very similar to before. It calls `encryptDictHelper(key, (label_bytes, token + ct))` for each `label_bytes: token + ct` entry of the dictionary.
And `encryptDictHelper` does the following operations:
```python
def encryptDictHelper(key, dict_item):
    label = dict_item[0]
    value = dict_item[1]
    K1 = utils.HashMAC(key, b'1'+label)[:16]
    K2 = utils.HashMAC(key, b'2'+label)[:16]
    ct_label = utils.Hash(K1)
    ct_value = utils.SymmetricEncrypt(K2, value)
    return ct_label, ct_value
```
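Notice that `K1` and `K2` are computed exactly like in `DESClass.tokenGen`, so they are the two halves of the token for the label `(node, root)`: `K1` is `token[:16]` and `K2` is `token[16:]`. Therefore, the helper just outputs the SHA256 hash of `K1` and encrypts `token + ct` (the token and ciphertext of `(next, root)`) with AES using `K2` as the key. Those are the key-value pairs of the encrypted dictionary `EDB`.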
Understanding queries
There are still two methods we have not talked about yet: `GESClass.search` and `DESClass.search`. The first one calls the second one under the hood, and it is used to find the shortest path between two nodes of the graph:
```python
class GESClass:
    # ...
    def search(self, token: bytes, encrypted_db: dict[bytes, bytes]) -> Tuple(bytes, bytes):
        '''
        Input: Search token
        Output: (tokens, cts)
        '''
        resp, tok = b"", b""
        curr = token
        while True:
            value = DES.search(curr, encrypted_db)
            if value == b'':
                break
            curr = value[:32]
            resp += value[32:]
            tok += curr
        return tuple([tok, resp])
    # ...

class DESClass:
    # ...
    def search(self, search_token: bytes, encrypted_db: dict[bytes, bytes]) -> bytes:
        '''
        Input: Search token and EDX
        Output: The corresponding encrypted value.
        '''
        K1 = search_token[:16]
        K2 = search_token[16:]
        hash_val = utils.Hash(K1)
        if hash_val in encrypted_db:
            ct_value = encrypted_db[hash_val]
            return utils.SymmetricDecrypt(K2, ct_value)
        else:
            return b''
    # ...
```
Basically, to search the shortest path from `A` to `B`, a `token` is generated (which identifies the label `(A, B)`) and it is looked up in `EDB` (also named `encrypted_db`). Remember that `token = K1 + K2`, so the hash of `K1` is the key of the dictionary item and `K2` is the AES key that decrypts the value of the dictionary item. Only the decrypted value is returned from `DESClass.search`.
Imagine that the node after `A` on the shortest path is `C`. Then the output of `DESClass.search` in this example would be `token + ct`, where this `token` corresponds to `(C, B)` and `ct` is the AES encryption of precisely `(C, B)` under `key_SKE`.
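Each result of `DESClass.search` is split into `token` and `ct`, which are appended to `tok` and `resp`, respectively, and the new `token` is used for the next lookup. The loop stops when a lookup fails, which happens at the label `(B, B)`, since `computeSDSP` never creates an entry for it (so the last token in `tok` is the token of `(B, B)`). Finally, the tuple `(tok, resp)` is returned from `GESClass.search`.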
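The chaining performed by `GESClass.search` can be simulated without any AES. In this sketch, the `token + ct` values are stored in a toy dictionary keyed by the SHA256 hash of `K1`, but both AES layers are omitted (the `CT(...)` strings are placeholders for the ciphertexts), and `KEY_DES` is a hypothetical constant. The point is to see which tokens end up in `tok`, using the SDSP of the cryptoGRAPHy 2 sanity-check graph with root `1`:

```python
import hashlib
import hmac

KEY_DES = b'\x00' * 16  # hypothetical fixed key, only for this demo

def token_gen(label: bytes) -> bytes:
    # K1 || K2, as in DESClass.tokenGen
    k1 = hmac.new(KEY_DES, b'1' + label, hashlib.sha256).digest()[:16]
    k2 = hmac.new(KEY_DES, b'2' + label, hashlib.sha256).digest()[:16]
    return k1 + k2

def label(a, b) -> bytes:
    return f'{a},{b}'.encode()

# Next hops towards root 1 (5 -> 2 -> 1, 7 -> 4 -> 1, ...)
root = 1
next_hop = {2: 1, 4: 1, 6: 1, 5: 2, 7: 4}

# Toy EDB: sha256(K1 of (node, root)) -> token(next, root) + placeholder for ct.
# Both AES layers are omitted on purpose.
edb = {}
for node, nxt in next_hop.items():
    k1 = token_gen(label(node, root))[:16]
    edb[hashlib.sha256(k1).digest()] = token_gen(label(nxt, root)) + b'CT(' + label(nxt, root) + b')'

def search(token: bytes):
    # Same loop as GESClass.search, with DES.search reduced to a dictionary lookup
    resp, toks = b'', b''
    curr = token
    while True:
        value = edb.get(hashlib.sha256(curr[:16]).digest(), b'')
        if value == b'':
            break
        curr = value[:32]
        resp += value[32:]
        toks += curr
    return toks, resp

toks, resp = search(token_gen(label(7, root)))
print(resp)                                       # b'CT(4,1)CT(1,1)'
print(toks[-32:] == token_gen(label(root, root)))  # True
```

The final token in `tok` belongs to `(root, root)` (the lookup for that label is the one that fails), which is what cryptoGRAPHy 2 and 3 rely on.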
Summary
- `GESClass.keyGen`: Generates `key` as `key_SKE` and `key_DES`. Then, `key[:16]` is used for AES encryption and `key[16:]` for HMAC
- Single-destination shortest path (SDSP): It represents all the shortest paths from all nodes to a single destination. Each entry is represented as `((node, root), (next, root))`, so that entries can be chained together to build a shortest path from `node` to `root`
- `SPDX`: A dictionary that maps `(node, root)` to `token + ct` (related to `(next, root)`)
- `DESClass.tokenGen` (and `GESClass.tokenGen`): Defines a unique token for a label `(node, root)`. It is a 32-byte value composed of the 16-byte values `K1` and `K2`
- `EDB` (`encrypted_db`): A dictionary that maps the SHA256 hash of `K1` to the AES encryption of `token + ct`, using `K2` as the key
- `GESClass.search`: Searches the shortest path from a source to a destination in the SDSP and returns the concatenation of all the tokens along the path (`tok`) and the concatenation of all the ciphertexts `ct` (`resp`), both taken from the decrypted `token + ct` values
cryptoGRAPHy 1
We are given the source code of the server, which uses some functions from the Graph Encryption Scheme library:
```python
from lib import GES, utils
import networkx as nx
import random
from SECRET import flag, decrypt

NODE_COUNT = 130
EDGE_COUNT = 260
SECURITY_PARAMETER = 16

def gen_random_graph(node_count: int, edge_count: int) -> nx.Graph:
    nodes = [i for i in range(1, node_count + 1)]
    edges = []
    while len(edges) < edge_count:
        u, v = random.choices(nodes, k=2)
        if u != v and (u, v) not in edges and (v, u) not in edges:
            edges.append([u, v])
    return utils.generate_graph(edges)

if __name__ == '__main__':
    try:
        print("[+] Generating random graph...")
        G = gen_random_graph(NODE_COUNT, EDGE_COUNT)
        myGES = GES.GESClass(cores=4, encrypted_db={})
        key = myGES.keyGen(SECURITY_PARAMETER)
        print(f"[*] Key: {key.hex()}")
        print("[+] Encrypting graph...")
        enc_db = myGES.encryptGraph(key, G)
        print("[!] Answer 50 queries to get the flag. In each query, input the shortest path \
decrypted from response. It will be a string of space-separated nodes from \
source to destination, e.g. '1 2 3 4'.")
        for q in range(50):
            while True:
                u, v = random.choices(list(G.nodes), k=2)
                if nx.has_path(G, u, v):
                    break
            print(f"[+] Query {q+1}/50: {u} {v}")
            token = myGES.tokenGen(key, (u, v))
            _, resp = myGES.search(token, enc_db)
            print(f"[*] Response: {resp.hex()}")
            ans = input("> Original query: ").strip()
            if ans != decrypt(u, v, resp, key):
                print("[!] Wrong answer!")
                exit()
        print(f"[+] Flag: {flag}")
    except:
        exit()
```
Notice that there is a function called `decrypt` that is hidden from us.
Source code analysis
First, the server generates a random graph, generates a key (which is printed out!) and encrypts the graph:
```python
print("[+] Generating random graph...")
G = gen_random_graph(NODE_COUNT, EDGE_COUNT)
myGES = GES.GESClass(cores=4, encrypted_db={})
key = myGES.keyGen(SECURITY_PARAMETER)
print(f"[*] Key: {key.hex()}")
print("[+] Encrypting graph...")
enc_db = myGES.encryptGraph(key, G)
```
After that, we are asked to answer 50 queries correctly to get the flag. For each query, the server takes two random connected nodes (`u` and `v`) of the graph and finds the shortest path between them. We are given `u` and `v` in plaintext and the `resp` part of the output of `GESClass.search` (the ciphertexts), but not the token:
```python
for q in range(50):
    while True:
        u, v = random.choices(list(G.nodes), k=2)
        if nx.has_path(G, u, v):
            break
    print(f"[+] Query {q+1}/50: {u} {v}")
    token = myGES.tokenGen(key, (u, v))
    _, resp = myGES.search(token, enc_db)
    print(f"[*] Response: {resp.hex()}")
    ans = input("> Original query: ").strip()
    if ans != decrypt(u, v, resp, key):
        print("[!] Wrong answer!")
        exit()
print(f"[+] Flag: {flag}")
```
Solution
Recall from the Graph Encryption Scheme library analysis that every chunk of `resp` is AES-encrypted using `key_SKE` as the key, which is `key[:16]` (and the server prints the key at the beginning). So, we can find all the nodes of the path from `u` to `v` (including `v` but not `u`) by decrypting every 32-byte chunk (16 bytes of ciphertext and 16 bytes of IV) and keeping the first number of each decrypted `next,v` pair. The value of `u` is printed by the server, so it is not a problem here.
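A quick sanity check of the 32-byte chunk size: PKCS#7 padding always adds between 1 and 16 bytes, and `SymmetricEncrypt` appends the 16-byte IV. Node labels here go up to 130, so the longest plaintext is `b'130,130'` (7 bytes), which fits in one AES block. The helper names below are illustrative, not from the challenge:

```python
BLOCK = 16

def cbc_output_len(plaintext_len: int) -> int:
    # PKCS#7 adds 1..16 bytes, so the padded length is the next multiple of 16 strictly above
    padded = (plaintext_len // BLOCK + 1) * BLOCK
    # SymmetricEncrypt returns ct || iv, and the IV is one more block
    return padded + BLOCK

def split_resp(resp: bytes):
    # One (ciphertext, IV) pair per hop, as in the solve script
    return [(resp[i:i + 16], resp[i + 16:i + 32]) for i in range(0, len(resp), 32)]

print(cbc_output_len(len(b'130,130')))  # 32 (longest label with node ids up to 130)
print(cbc_output_len(len(b'1,2')))      # 32
```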
We can implement this easily with Python. This is the relevant code:
```python
from pwn import *  # pwntools: remote tubes and io.progress()
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

io = get_process()  # helper from the full solve.py that opens the connection

io.recvuntil(b'[*] Key: ')
key = bytes.fromhex(io.recvline().decode())
SKE, DES = key[:16], key[16:]  # AES key and HMAC key

round_prog = io.progress('Round')

for r in range(50):
    round_prog.status(f'{r + 1} / 50')
    io.recvuntil(b'/50: ')
    start = io.recvuntil(b' ')  # source node u, including the trailing space
    io.recvuntil(b'[*] Response: ')
    res = bytes.fromhex(io.recvline().decode())

    # Each hop is a 16-byte ciphertext followed by its 16-byte IV
    ct_iv = [(res[i:i+16], res[i+16:i+32]) for i in range(0, len(res), 32)]
    shortest_path = []

    for ct, iv in ct_iv:
        cipher = AES.new(SKE, AES.MODE_CBC, iv)
        pt = unpad(cipher.decrypt(ct), AES.block_size)  # b'next,v'
        shortest_path.append(pt.split(b',')[0])

    io.sendlineafter(b'> Original query: ', start + b' '.join(shortest_path))

round_prog.success('50 / 50')
print(io.recv().decode().strip())
```
Flag
Once we run the script, we will find the flag:
```console
$ python3 solve.py chals.sekai.team 3001
[+] Opening connection to chals.sekai.team on port 3001: Done
[+] Round: 50 / 50
[+] Flag: SEKAI{GES_15_34sy_2_br34k_kn@w1ng_th3_k3y}
[*] Closed connection to chals.sekai.team port 3001
```
The full script can be found here: solve.py.
cryptoGRAPHy 2
For this second part, we are given the source code of the server, which again uses some functions from the Graph Encryption Scheme library:
```python
from lib import GES, utils
import networkx as nx
import random
from SECRET import flag, get_SDSP_node_degrees

'''
get_SDSP_node_degrees(G, dest) returns the node degrees in the single-destination shortest path (SDSP) tree, sorted in ascending order.
For example, if G has 5 nodes with edges (1,2),(1,3),(2,3),(2,5),(4,5) and dest=1, returns "1 1 2 2 2".

[+] Original:          [+] SDSP:
1--2--5--4             1--2--5--4
| /                    |
3                      3
'''

# Another example for sanity check
TestGraph = utils.generate_graph([[1, 2], [1, 4], [1, 6], [6, 5], [6, 7], [4, 7], [2, 5]])
assert get_SDSP_node_degrees(TestGraph, 1) == '1 1 1 2 2 3'

NODE_COUNT = 130
EDGE_PROB = 0.031
SECURITY_PARAMETER = 32

def gen_random_graph() -> nx.Graph:
    return nx.fast_gnp_random_graph(n=NODE_COUNT, p=EDGE_PROB)

if __name__ == '__main__':
    try:
        print("[!] Pass 10 challenges to get the flag:")
        for q in range(10):
            print(f"[+] Challenge {q+1}/10. Generating random graph...")
            while True:
                G = gen_random_graph()
                if nx.is_connected(G):
                    break
            myGES = GES.GESClass(cores=4, encrypted_db={})
            key = myGES.keyGen(SECURITY_PARAMETER)
            print("[+] Encrypting graph...")
            enc_db = myGES.encryptGraph(key, G)
            dest = random.choice(list(G.nodes()))
            print(f"[*] Destination: {dest}")
            attempts = NODE_COUNT
            while attempts > 0:
                attempts -= 1
                query = input("> Query u,v: ").strip()
                try:
                    u, v = map(int, query.split(','))
                    assert u in G.nodes() and v in G.nodes() and u != v
                except:
                    print("[!] Invalid query!")
                    break
                token = myGES.tokenGen(key, (u, v))
                print(f"[*] Token: {token.hex()}")
                tok, resp = myGES.search(token, enc_db)
                print(f"[*] Query Response: {tok.hex() + resp.hex()}")
            ans = input("> Answer: ").strip()
            if ans != get_SDSP_node_degrees(G, dest):
                print("[!] Wrong answer!")
                exit()
        print(f"[+] Flag: {flag}")
    except:
        exit()
```
SDSP and node degrees
Now, the server hides a function called `get_SDSP_node_degrees`, but at least it describes what it does:
```python
from SECRET import flag, get_SDSP_node_degrees

'''
get_SDSP_node_degrees(G, dest) returns the node degrees in the single-destination shortest path (SDSP) tree, sorted in ascending order.
For example, if G has 5 nodes with edges (1,2),(1,3),(2,3),(2,5),(4,5) and dest=1, returns "1 1 2 2 2".

[+] Original:          [+] SDSP:
1--2--5--4             1--2--5--4
| /                    |
3                      3
'''

# Another example for sanity check
TestGraph = utils.generate_graph([[1, 2], [1, 4], [1, 6], [6, 5], [6, 7], [4, 7], [2, 5]])
assert get_SDSP_node_degrees(TestGraph, 1) == '1 1 1 2 2 3'
```
This challenge helped me understand exactly what an SDSP is. The drawn example might be too small to generalize from, so let's draw the second example. This is the graph:
```
4----1----2
|    |    |
7----6----5
```
Now, the destination is `1`, so let's find the shortest path from every node:

```
2 -> 1
4 -> 1
5 -> 2 -> 1
6 -> 1
7 -> 4 -> 1
```

(Nodes `5` and `7` also have another path of length 2 through `6`, but only one shortest path per node is kept.)
So, the SDSP is the initial graph with the unused edges removed, and it is actually a tree:
```
4----1----2
|    |    |
7    6    5
```
And the output of `get_SDSP_node_degrees` is the list of the degrees of the nodes of the SDSP (the degree of a node is its number of adjacent nodes), sorted in ascending order. Therefore, the result for this example is `1 1 1 2 2 3`. We can replace the node identifiers by their corresponding degrees to illustrate the concept:
```
2----3----2
|    |    |
1    1    1
```
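The described behaviour of `get_SDSP_node_degrees` can be sketched as a BFS from `dest` that counts the degree of every node in the resulting tree. This is only a guess at the hidden function: it assumes that ties between equal-length paths are broken like a BFS visiting neighbours in edge-insertion order, which reproduces both examples from the server's docstring:

```python
from collections import Counter, deque

def sdsp_node_degrees(edges, dest) -> str:
    """Hypothetical re-implementation of get_SDSP_node_degrees for small graphs."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, []).append(v)
        adj.setdefault(v, []).append(u)
    degree = Counter({dest: 0})
    seen, queue = {dest}, deque([dest])
    while queue:
        current = queue.popleft()
        for neighbour in adj.get(current, []):
            if neighbour not in seen:
                seen.add(neighbour)
                queue.append(neighbour)
                # Tree edge neighbour -> current adds 1 to the degree of both endpoints
                degree[neighbour] += 1
                degree[current] += 1
    return ' '.join(map(str, sorted(degree.values())))

print(sdsp_node_degrees([(1, 2), (1, 3), (2, 3), (2, 5), (4, 5)], 1))                  # 1 1 2 2 2
print(sdsp_node_degrees([(1, 2), (1, 4), (1, 6), (6, 5), (6, 7), (4, 7), (2, 5)], 1))  # 1 1 1 2 2 3
```

Tie-breaking matters: if `5` and `7` were both attached to `6` instead, the result would be `1 1 1 1 3 3`. So the degree sequence really depends on the exact SDSP stored in the encrypted database, and conveniently the query responses reveal exactly that tree.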
Source code analysis
This time, the server does not print the `key`… We need to compute the output of `get_SDSP_node_degrees` a total of 10 times, once for each new random graph.
The initialization is the same as before (generate a random graph, generate a key and encrypt the graph), except that the graph is now guaranteed to be connected. After that, the server picks a random destination node, which is printed out in plaintext:
```python
print("[!] Pass 10 challenges to get the flag:")
for q in range(10):
    print(f"[+] Challenge {q+1}/10. Generating random graph...")
    while True:
        G = gen_random_graph()
        if nx.is_connected(G):
            break
    myGES = GES.GESClass(cores=4, encrypted_db={})
    key = myGES.keyGen(SECURITY_PARAMETER)
    print("[+] Encrypting graph...")
    enc_db = myGES.encryptGraph(key, G)
    dest = random.choice(list(G.nodes()))
    print(f"[*] Destination: {dest}")
```
The server allows us to query up to 130 times (`NODE_COUNT`, the number of nodes in the graph), and we are given the `token` for our input `(u, v)`, as well as `tok` and `resp` from `GESClass.search` (concatenated):
```python
for q in range(10):
    # ... (graph generation and encryption, as above)
    attempts = NODE_COUNT
    while attempts > 0:
        attempts -= 1
        query = input("> Query u,v: ").strip()
        try:
            u, v = map(int, query.split(','))
            assert u in G.nodes() and v in G.nodes() and u != v
        except:
            print("[!] Invalid query!")
            break
        token = myGES.tokenGen(key, (u, v))
        print(f"[*] Token: {token.hex()}")
        tok, resp = myGES.search(token, enc_db)
        print(f"[*] Query Response: {tok.hex() + resp.hex()}")
    ans = input("> Answer: ").strip()
    if ans != get_SDSP_node_degrees(G, dest):
        print("[!] Wrong answer!")
        exit()
print(f"[+] Flag: {flag}")
```
Solution
Our goal is to rebuild the SDSP for the given destination node. Therefore, all our queries must be of the form `(node, root)`, where `root` is the given destination.
Also, notice that there is a check `u != v`. Once we have the other 129 query results, we can use it to break out of the `while` loop by sending `(root, root)`, and then submit the output of `get_SDSP_node_degrees` that we have computed.
For each round, we will do the following:
```python
NODES_SIZE = 130  # same as NODE_COUNT on the server: nodes are 0..129

data = {}
nodes = [0] * NODES_SIZE  # nodes[i] = hex token identifying node i

io.recvuntil(b'[*] Destination: ')
dest = int(io.recvline().decode())

for i in range(NODES_SIZE):
    if i == dest:
        continue
    io.sendlineafter(b'> Query u,v: ', f'{i},{dest}'.encode())
    io.recvuntil(b'[*] Token: ')
    token = bytes.fromhex(io.recvline().decode())
    io.recvuntil(b'[*] Query Response: ')
    res = bytes.fromhex(io.recvline().decode())
    # res = tok || resp; the first half holds the 32-byte tokens of the path
    keys = [token] + [res[j:j+32] for j in range(0, len(res) // 2, 32)]
    data[i] = list(map(lambda b: b.hex(), keys))
    if nodes[dest] == 0:
        nodes[dest] = keys[-1].hex()  # last token is (dest, dest)
    if nodes[i] == 0:
        nodes[i] = keys[0].hex()      # query token is (i, dest)
```
We take the `token`, which is related to `(node, root)`, and then the token part of the response from the server: the first half contains all the 32-byte tokens for each `(next, root)` label in the shortest path (both halves have the same length, since every token and every ciphertext-plus-IV chunk is 32 bytes). We save them in the `data` dictionary, indexed by node.

Also, we use the `token` to identify each node (saved in the `nodes` list), and the last token of the response as the identifier of the destination (`(root, root)`).
Once we have this information, we can generate the SDSP graph by saving all possible edges into a set (following shortest paths):
```python
edges = set()
for d in data.values():
    # d = [token(node, dest), token(next, dest), ..., token(dest, dest)]
    path = [nodes.index(t) for t in d]
    for a, b in zip(path, path[1:]):
        if (b, a) not in edges:  # undirected: keep each edge once
            edges.add((a, b))
G = generate_graph(edges)
```
And the last step is to take the degrees of the nodes and sort them:
```python
degrees = sorted(map(lambda t: t[1], G.degree))
io.sendlineafter(b'> Query u,v: ', f'{dest},{dest}'.encode())  # u == v ends the query loop
io.sendlineafter(b'> Answer: ', ' '.join(map(str, degrees)).encode())
```
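The edge-building and degree-counting steps can also be sketched without `networkx`: only equality between tokens is needed, so opaque strings work as node identifiers. This alternative sketch uses `collections.Counter` instead of `generate_graph` and `G.degree`, and is exercised on hypothetical tokens `t1` to `t7` standing for the nodes of the sanity-check example:

```python
from collections import Counter

def degrees_from_token_paths(paths) -> str:
    """paths: one list per query (node, dest), holding the opaque tokens
    [token(node, dest), token(next, dest), ..., token(dest, dest)].
    Only token equality is used, so plaintext node ids are never needed."""
    edges = set()
    for path in paths:
        for a, b in zip(path, path[1:]):
            edges.add(frozenset((a, b)))  # undirected edge, stored once
    degree = Counter()
    for edge in edges:
        for endpoint in edge:
            degree[endpoint] += 1
    return ' '.join(map(str, sorted(degree.values())))

example_paths = [['t2', 't1'], ['t4', 't1'], ['t5', 't2', 't1'], ['t6', 't1'], ['t7', 't4', 't1']]
print(degrees_from_token_paths(example_paths))  # 1 1 1 2 2 3
```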
Flag
We can run the script to compute the above 10 times to get the flag:
```console
$ python3 solve.py chals.sekai.team 3062
[+] Opening connection to chals.sekai.team on port 3062: Done
[+] Round: 10 / 10
[+] Flag: SEKAI{3ff1c13nt_GES_4_Shortest-Path-Queries-_-}
[*] Closed connection to chals.sekai.team port 3062
```
The full script can be found here: solve.py.
cryptoGRAPHy 3
For the last challenge of the series, we are given the source code of the server again:
```python
from itertools import product, chain
from multiprocessing import Pool
from lib import GES
import networkx as nx
import random
import time
from SECRET import flag, generate_tree, decrypt

NODE_COUNT = 60
SECURITY_PARAMETER = 128

MENU = '''============ MENU ============
1. Graph Information
2. Query Responses
3. Challenge
4. Exit
=============================='''

def query_resps(cores: int, key: bytes, G: nx.Graph, myGES: GES.GESClass, enc_db):
    n = len(G.nodes())
    query_list = []
    queries = product(set(), set())
    for component in nx.connected_components(G):
        queries = chain(queries, product(component, component))
    iterable = product([key], queries)
    chunk = n * n // cores
    with Pool(cores) as pool:
        for token in pool.istarmap(myGES.tokenGen, iterable, chunksize=chunk):
            tok, resp = myGES.search(token, enc_db)
            query_list.append((token.hex() + tok.hex(), resp.hex()))
    random.shuffle(query_list)
    return query_list

if __name__ == '__main__':
    try:
        G = generate_tree()
        assert len(G.nodes()) == NODE_COUNT
        myGES = GES.GESClass(cores=4, encrypted_db={})
        key = myGES.keyGen(SECURITY_PARAMETER)
        enc_db = myGES.encryptGraph(key, G)
        t = time.time()
        print("[!] Recover 10 queries in 30 seconds. It is guaranteed that each answer is unique.")
        while True:
            print(MENU)
            option = input("> Option: ").strip()
            if option == "1":
                print("[!] Graph information:")
                print("[*] Edges:", G.edges())
            elif option == "2":
                print(f"[*] Query Responses: ")
                resp = query_resps(4, key, G, myGES, enc_db)
                for r in resp:
                    print(f"{r[0]} {r[1]}")
            elif option == "3":
                break
            else:
                exit()
        print("[!] In each query, input the shortest path decrypted from response. \
It will be a string of space-separated nodes from source to destination, e.g. '1 2 3 4'.")
        for q in range(10):
            print(f"[+] Challenge {q+1}/10.")
            while True:
                u, v = random.choices(list(G.nodes()), k=2)
                if u != v and nx.has_path(G, u, v):
                    break
            token = myGES.tokenGen(key, (u, v))
            print(f"[*] Token: {token.hex()}")
            tok, resp = myGES.search(token, enc_db)
            print(f"[*] Query Response: {tok.hex() + resp.hex()}")
            ans = input("> Original query: ").strip()
            if ans != decrypt(u, v, resp, key):
                print("[!] Wrong answer!")
                exit()
        if time.time() - t > 30:
            print("[!] Time's up!")
            exit()
        print(f"[+] Flag: {flag}")
    except:
        exit()
```
Source code analysis
Again, the server does not print the `key`… But instead, we are given the initial graph in plaintext (option `1`):
```python
if option == "1":
    print("[!] Graph information:")
    print("[*] Edges:", G.edges())
```
In particular, the graph is a tree generated with a hidden function called `generate_tree`. At first this seemed a bit weird to me, since `networkx` has a function named `random_tree`, so there is no obvious reason to hide it. When doing some local tests, I found out that if two nodes are symmetric (for example, at the same level with identical subtrees hanging from them), the challenge is not solvable, because those nodes can't be distinguished once they are encrypted. Looking at the tree generated by the server, there was no such pair of nodes, so the challenge was solvable; presumably the hidden `generate_tree` discards that situation (the server also guarantees that each answer is unique).
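To see why indistinguishable nodes break the attack, note that a tree isomorphism is only unique when the tree has no non-trivial automorphism (a relabelling of the nodes that preserves all edges). The brute-force sketch below, usable only for tiny trees and unrelated to how `networkx` finds isomorphisms, counts automorphisms: a star or a path has several, so tokens could be matched to nodes in more than one way, while the 7-node tree at the end has exactly one:

```python
from itertools import permutations

def count_automorphisms(edges) -> int:
    """Brute-force count of adjacency-preserving node permutations (tiny trees only)."""
    nodes = sorted({n for e in edges for n in e})
    edge_set = {frozenset(e) for e in edges}
    count = 0
    for image in permutations(nodes):
        mapping = dict(zip(nodes, image))
        # A bijection mapping every edge onto an edge preserves the whole tree
        if all(frozenset((mapping[u], mapping[v])) in edge_set for u, v in edges):
            count += 1
    return count

star = [(0, 1), (0, 2), (0, 3)]                              # three interchangeable leaves
path = [(1, 2), (2, 3)]                                      # can be read in both directions
asymmetric = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (3, 7)]  # branches of length 2, 1, 3
print(count_automorphisms(star), count_automorphisms(path), count_automorphisms(asymmetric))  # 6 2 1
```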
In option `2`, we are given the responses (`token + tok` and `resp`) to all possible queries in the graph, which has a total of 60 nodes. Therefore, we are given a total of $60^2 = 3600$ responses (including the 60 queries of the form `(u, u)`). The problem is that the list of responses is shuffled:
```python
def query_resps(cores: int, key: bytes, G: nx.Graph, myGES: GES.GESClass, enc_db):
    n = len(G.nodes())
    query_list = []
    queries = product(set(), set())
    for component in nx.connected_components(G):
        queries = chain(queries, product(component, component))
    iterable = product([key], queries)
    chunk = n * n // cores
    with Pool(cores) as pool:
        for token in pool.istarmap(myGES.tokenGen, iterable, chunksize=chunk):
            tok, resp = myGES.search(token, enc_db)
            query_list.append((token.hex() + tok.hex(), resp.hex()))
    random.shuffle(query_list)
    return query_list
```
Having received all the information, we can use option `3` to begin the actual challenge: for 10 rounds, the server takes two random, distinct, connected nodes `u` and `v` of the graph and asks us to enter the original query (that is, the shortest path between them), all within 30 seconds of the start of the connection. We are given the `token` for `(u, v)` and the output of `GESClass.search`:
```python
print("[!] In each query, input the shortest path decrypted from response. \
It will be a string of space-separated nodes from source to destination, e.g. '1 2 3 4'.")
for q in range(10):
    print(f"[+] Challenge {q+1}/10.")
    while True:
        u, v = random.choices(list(G.nodes()), k=2)
        if u != v and nx.has_path(G, u, v):
            break
    token = myGES.tokenGen(key, (u, v))
    print(f"[*] Token: {token.hex()}")
    tok, resp = myGES.search(token, enc_db)
    print(f"[*] Query Response: {tok.hex() + resp.hex()}")
    ans = input("> Original query: ").strip()
    if ans != decrypt(u, v, resp, key):
        print("[!] Wrong answer!")
        exit()
if time.time() - t > 30:
    print("[!] Time's up!")
    exit()
print(f"[+] Flag: {flag}")
```
Solution
First of all, we can take all the information from options `1` and `2`:
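```python
import ast

io = get_process()  # helper from the full solve.py that opens the connection

io.sendlineafter(b'> Option: ', b'1')
io.recvuntil(b'[*] Edges: ')
# The edge list is printed as a Python literal, e.g. [(0, 1), (1, 2), ...]
G = generate_graph(ast.literal_eval(io.recvline().decode()))

io.sendlineafter(b'> Option: ', b'2')
io.recvline()
queries = []

while True:
    line = io.recvline().decode()[:-1]
    if 'MENU' in line:
        break
    tok, res = map(bytes.fromhex, line.split(' '))
    queries.append({
        'token': tok[:32].hex(),                                      # token of (node, root)
        'tok': [tok[i:i+32].hex() for i in range(32, len(tok), 32)],  # tokens along the path
        'res': [res[i:i+32].hex() for i in range(0, len(res), 32)]    # ciphertext + IV chunks
    })
```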
Notice how I used a `queries` list (which will contain a total of 3600 elements), where each element holds the query `token` and two lists for `tok` and `res`.
Taking a look at the `queries` list, we can see that there are exactly 60 elements with empty `tok` and `res` lists, since those queries are of the form `(root, root)`. We can use these elements to identify the nodes of the graph, so let's separate them into another list:
```python
toks_0 = [q['token'] for q in queries if len(q['tok']) == 0]  # tokens of (root, root)
```
In fact, each token in `toks_0` also appears in the `tok` lists of other elements of `queries`, always at the last position, because `(root, root)` is where every path to `root` ends. As a result, for a given token from `toks_0`, we can find all the tokens of the form `(node, root)` by looking for it in their `tok` list.
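Here is a sketch of grouping the parsed queries by destination in a single pass, using the fact that the last token of every non-empty `tok` list is the token of `(root, root)`. The `queries` format is the one built by the parsing loop above; the function names and demo data are hypothetical (a path `1 - 2 - 3`, with the token of `(a, d)` written as `'T<a><d>'`):

```python
def group_by_destination(queries):
    """Group parsed queries by the token of their destination (root, root)."""
    roots = [q['token'] for q in queries if not q['tok']]
    groups = {root: [] for root in roots}
    for q in queries:
        if q['tok']:
            groups[q['tok'][-1]].append(q)  # the path ends at token(root, root)
        else:
            groups[q['token']].append(q)    # the (root, root) query itself
    return groups

def tree_edges(group):
    # tok[0] is token(next, root), so (next, node) is an edge of the tree
    return {(q['tok'][0], q['token']) for q in group if q['tok']}

demo = [
    {'token': 'T11', 'tok': []}, {'token': 'T21', 'tok': ['T11']}, {'token': 'T31', 'tok': ['T21', 'T11']},
    {'token': 'T22', 'tok': []}, {'token': 'T12', 'tok': ['T22']}, {'token': 'T32', 'tok': ['T22']},
    {'token': 'T33', 'tok': []}, {'token': 'T23', 'tok': ['T33']}, {'token': 'T13', 'tok': ['T23', 'T33']},
]
groups = group_by_destination(demo)
print(sorted(tree_edges(groups['T11'])))  # [('T11', 'T21'), ('T21', 'T31')]
```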
And what's more interesting, we can join all these tokens together into a graph, linking each `(node, root)` token with the first token of its `tok` list (its next hop). This graph will be a tree, because the SDSP of a tree is the tree itself. This is very important, because trees do not contain any cycle, so the path between two nodes is actually the shortest path between them. And what's even more mind-blowing, we can find a mapping between the token tree and the plaintext tree we have from the beginning.
In mathematical terms, this is called a tree isomorphism, which defines a one-to-one relation between the nodes of two trees that preserves adjacency. Fortunately, `networkx` provides a function to find a tree isomorphism (`tree_isomorphism`), so we can match each `token` to a plaintext node:
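```python
from networkx.algorithms.isomorphism.tree_isomorphism import tree_isomorphism

mappings = {}

def define_tree(queries):
    # All queries towards one destination: each token is one node of the encrypted tree
    nodes = [q['token'] for q in queries]
    edges = set()
    for node in nodes:
        for q in queries:
            # q['tok'][0] is the next hop of q['token'] towards the destination
            if len(q['tok']) and q['tok'][0] == node:
                edges.add((node, q['token']))
    GG = generate_graph(edges)
    isomorphism = tree_isomorphism(GG, G)  # list of (token, plaintext node) pairs
    for enc, node in isomorphism:
        mappings[enc] = node
```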
And this process must be done 60 times, since there are a total of 60 different destinations:
```python
for tok in toks_0:
    define_tree([q for q in queries if tok in q['tok'] or tok == q['token']])
```
Once we have the mapping between encrypted tokens and plaintext nodes, we can easily find the shortest path of the query we are given by looking up each token in the `mappings` dictionary. And that is precisely what we are doing here:
```python
io.sendlineafter(b'> Option: ', b'3')
round_prog = io.progress('Round')

for r in range(10):
    round_prog.status(f'{r + 1} / 10')
    io.recvuntil(b'[*] Token: ')
    token = bytes.fromhex(io.recvline().decode())
    io.recvuntil(b'[*] Query Response: ')
    res = bytes.fromhex(io.recvline().decode())
    # Query token plus the path tokens (first half of tok || resp)
    keys = [token.hex()] + [res[i:i+32].hex() for i in range(0, len(res) // 2, 32)]
    shortest_path = []
    for key in keys:
        shortest_path.append(mappings[key])
    io.sendlineafter(b'> Original query: ', ' '.join(map(str, shortest_path)).encode())

round_prog.success('10 / 10')
print(io.recv().decode().strip())
```
Flag
Once we run the script we will see the flag (which points to a paper that describes this query recovery attack in much more detail):
```console
$ python3 solve.py chals.sekai.team 3023
[+] Opening connection to chals.sekai.team on port 3023: Done
[+] Round: 10 / 10
[+] Flag: SEKAI{Full_QR_Attack_is_not_easy_https://eprint.iacr.org/2022/838.pdf}
[*] Closed connection to chals.sekai.team port 3023
```
The full script can be found here: solve.py.