In this challenge, we are given a simple block cipher (though not fully invertible), based on 3 modular multiplications.
```python
# Sage mode
import random  # in Sage, the bare name `random` is a function, not the module

P = 247359019496198933  # 2**57.78
C = 223805275076627807  # 2**57.64
M = 2**60
K0 = random.randint(1, P-1)
K1 = random.randint(1, P-1)

# not a bijection? can be adjusted but I'm lazy
def encrypt_block(x):
    tmp = x * K0 % P
    tmp = tmp * C % M
    tmp = tmp * K1 % P
    return tmp
```
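For experimenting outside Sage, the cipher can be ported to plain Python. The sketch below is our own port: the keys are passed as explicit arguments `k0`, `k1` (not globals). It shows one reason the map is not a bijection. The first step reduces modulo $P$, so $x$ and $x + P$ always encrypt to the same block, and every output lies in $[0, P)$, a smaller range than the 64-bit inputs.

```python
import random

P = 247359019496198933
C = 223805275076627807
M = 2**60

def encrypt_block(x, k0, k1):
    """Plain-Python port of the challenge cipher with explicit keys."""
    tmp = x * k0 % P
    tmp = tmp * C % M
    tmp = tmp * k1 % P
    return tmp

k0 = random.randint(1, P - 1)
k1 = random.randint(1, P - 1)
x = random.randrange(2**64 - P)

# x and x + P collide because the first step works modulo P
assert encrypt_block(x, k0, k1) == encrypt_block(x + P, k0, k1)
# outputs live in [0, P), fewer values than the 64-bit inputs
assert encrypt_block(x, k0, k1) < P
```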
We are also given $2^{24}$ random plaintext-ciphertext pairs and the encrypted flag.
Let's build a linearized relation between a plaintext and its ciphertext by introducing variables for the modular reductions.
Consider a plaintext-ciphertext pair $(x, y)$.
Let
$$
\begin{align}
t_1 &= \lfloor K_0x / P \rfloor < x, \\
t_2 &= \lfloor (K_0x \bmod P)\,C / M \rfloor < PC/M, \\
t_3 &= \lfloor K_1^{-1}y / P \rfloor - \epsilon < y,
\end{align}
$$
where $K_1^{-1}$ is taken in $[0, P)$ and $0 \le \epsilon \le \lfloor M / P \rfloor = 4$ is chosen so that $K_1^{-1}y - t_3P$ equals the value after the second encryption step (`tmp = tmp * C % M`). This small correction is needed because the value after the second step is below $M = 2^{60}$ but can exceed $P \approx 2^{57.78}$. It is then reduced modulo $P$, so a couple of bits of information are lost.
Then, following two steps of encryption of $x$ and one step of decryption of $y$, we get:

$$
\begin{align}
& (K_0x - t_1P)C - t_2M = K_1^{-1}y - t_3P, \\
\Rightarrow\ & xC\cdot K_0 - y\cdot K_1^{-1} - PC\cdot t_1 - M\cdot t_2 + P\cdot t_3 = 0,
\end{align}
$$

where the unknowns are bounded by

$$
\begin{align}
0 \le~ & K_0 < P,\\
0 \le~ & K_1^{-1} < P,\\
0 \le~ & t_1 < x,\\
0 \le~ & t_2 < PC/M,\\
-4 \le~ & t_3 < y.
\end{align}
$$
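The relation is easy to sanity-check numerically. The following is a plain-Python sketch with random keys; the helper name `check_pair` is ours. It computes the hidden values $t_1, t_2, t_3, \epsilon$ for random plaintexts and asserts the identity and the bounds above. Note that $\epsilon$ is simply the number of times $P$ is subtracted in the third step.

```python
import math
import random

P = 247359019496198933
C = 223805275076627807
M = 2**60

# random keys; retry until K1 is invertible modulo P
K0 = random.randint(1, P - 1)
while True:
    K1 = random.randint(1, P - 1)
    if math.gcd(K1, P) == 1:
        break
K1i = pow(K1, -1, P)  # K_1^{-1} as an integer in [0, P)

def check_pair(x):
    """Encrypt x, recover t_1, t_2, t_3, eps and verify the linear relation."""
    a = x * K0 % P          # after step 1
    b = a * C % M           # after step 2
    y = b * K1 % P          # ciphertext
    t1 = x * K0 // P
    t2 = a * C // M
    eps = b // P            # number of P's removed by the last reduction
    t3 = K1i * y // P - eps
    assert 0 <= eps <= M // P == 4
    assert b == K1i * y - t3 * P
    assert x * C * K0 - y * K1i - P * C * t1 - M * t2 + P * t3 == 0
    # integer comparison for t2 < PC/M avoids float rounding
    assert 0 <= t1 < x and 0 <= t2 and t2 * M < P * C and -4 <= t3 < y
    return y, t1, t2, t3, eps

for _ in range(1000):
    check_pair(random.randrange(1, 2**64))
```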
Note that $x,y,t_1,t_2,t_3$ are different for each known data pair.
We can now use LLL to solve this constraint system. As an example, consider the following lattice for $n=3$ data pairs (rows as vectors):

$$
\begin{array}{c|ccc|cc|ccc|ccc|ccc}
 & \mathrm{eq}_0 & \mathrm{eq}_1 & \mathrm{eq}_2 & K_0 & K_1^{-1} & t_{1,0} & t_{1,1} & t_{1,2} & t_{2,0} & t_{2,1} & t_{2,2} & t_{3,0} & t_{3,1} & t_{3,2} \\
\hline
K_0 & x_0C & x_1C & x_2C & 1 & & & & & & & & & & \\
K_1^{-1} & y_0 & y_1 & y_2 & & 1 & & & & & & & & & \\
t_{1,0} & CP & & & & & 1 & & & & & & & & \\
t_{1,1} & & CP & & & & & 1 & & & & & & & \\
t_{1,2} & & & CP & & & & & 1 & & & & & & \\
t_{2,0} & M & & & & & & & & 1 & & & & & \\
t_{2,1} & & M & & & & & & & & 1 & & & & \\
t_{2,2} & & & M & & & & & & & & 1 & & & \\
t_{3,0} & P & & & & & & & & & & & 1 & & \\
t_{3,1} & & P & & & & & & & & & & & 1 & \\
t_{3,2} & & & P & & & & & & & & & & & 1
\end{array}
$$
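We are looking for a linear combination of rows, with coefficients $K_0, -K_1^{-1}, -t_{1,i}, -t_{2,i}, t_{3,i}$, that makes the first $n$ entries zero and keeps the remaining entries within our bounds. We can achieve this by scaling the coordinates (columns): the first $n$ columns should be multiplied by a very large number (forcing LLL to make them zero), and the other columns should be scaled inversely to their bounds. In the code, this is done by dividing each column by its bound and using bound $1$ for the equation columns; since all other bounds are large, the equation columns are heavy by comparison. After running LLL, we scale back.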
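To see which vector LLL is supposed to find, here is a small plain-Python sketch (no Sage; the helpers `basis` and `target_vector` are hypothetical names of our own). It builds the unscaled basis for a few random pairs, laid out like the matrix above, and combines the rows with the hidden coefficients. The first $n$ entries of the result vanish and the rest reproduce the coefficients themselves, which is exactly the short vector we hope LLL exposes after scaling.

```python
import math
import random

P = 247359019496198933
C = 223805275076627807
M = 2**60

def random_invertible():
    while True:
        k = random.randint(1, P - 1)
        if math.gcd(k, P) == 1:
            return k

def basis(pairs):
    """Unscaled lattice basis, one list per row, in the layout of the matrix."""
    n = len(pairs)
    rows = [[0] * (2 + 4 * n) for _ in range(2 + 3 * n)]
    for j in range(2 + 3 * n):
        rows[j][n + j] = 1  # identity block records the coefficients
    for i, (x, y) in enumerate(pairs):
        rows[0][i] = C * x
        rows[1][i] = y
        rows[2 + i][i] = C * P
        rows[2 + n + i][i] = M
        rows[2 + 2 * n + i][i] = P
    return rows

def target_vector(n=3):
    """Encrypt n random plaintexts and return (n, coefficients, combined row)."""
    K0, K1 = random.randint(1, P - 1), random_invertible()
    K1i = pow(K1, -1, P)
    pairs, t1s, t2s, t3s = [], [], [], []
    for _ in range(n):
        x = random.randrange(1, 2**64)
        a = x * K0 % P
        b = a * C % M
        y = b * K1 % P
        pairs.append((x, y))
        t1s.append(x * K0 // P)
        t2s.append(a * C // M)
        t3s.append(K1i * y // P - b // P)
    coeffs = [K0, -K1i] + [-t for t in t1s] + [-t for t in t2s] + t3s
    rows = basis(pairs)
    # integer linear combination of the rows with the hidden coefficients
    v = [sum(c * row[k] for c, row in zip(coeffs, rows))
         for k in range(len(rows[0]))]
    return n, coeffs, v

n, coeffs, v = target_vector()
assert v[:n] == [0] * n   # every equation column cancels
assert v[n:] == coeffs    # the rest of the vector is the secret data
```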
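```python
# read the known (plaintext, ciphertext) pairs; parsing stops at the flag line
data = []
with open("res") as f:
    for line in f:
        try:
            pt, ct = map(int, line.split())
        except ValueError:
            break
        data.append((pt, ct))
```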
Note that the bounds of $t_{1,i}$ and $t_{3,i}$ depend on the actual data ($t_{1,i} < x_i$, $t_{3,i} < y_i$). We are thus interested in the smallest plaintexts and ciphertexts. As we shall see, the 20 smallest pairs are enough!
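```python
# the smallest pairs give the tightest bounds on t_{1,i} and t_{3,i}
data.sort(key=lambda a: a[0]**2 + a[1]**2)
n = 20
pairs = data[:n]

# rows: K0, -K1^{-1}, t_{1,i}, t_{2,i}, t_{3,i}; columns: n equations + identity
m = matrix(QQ, 2 + 3*n, 2 + 4*n)
m[:, n:] = identity_matrix(2 + 3*n)
for i, (x, y) in enumerate(pairs):
    m[0, i] = C*x
    m[1, i] = y
    m[0*n + 2 + i, i] = C*P
    m[1*n + 2 + i, i] = M
    m[2*n + 2 + i, i] = P

# expected magnitude of each coordinate of the target vector
bounds = [1] * n + [P] * 2
bounds += [pt for pt, ct in pairs]
bounds += [P*C / M] * n
bounds += [ct for pt, ct in pairs]
assert len(bounds) == m.ncols()

# scale every column by 1/bound
for i, b in enumerate(bounds):
    m.set_column(i, m.column(i) / b)

# LLL
m = m.LLL()

# scale back
for i, b in enumerate(bounds):
    m.set_column(i, m.column(i) * b)
```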
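```python
# a short row holds (K0, -K1^{-1}) in columns n, n+1, possibly negated
for irow, row in enumerate(m):
    for sign in (1, -1):
        k0, negk1i = [sign * ZZ(v) for v in row[n:n+2]]
        if gcd(negk1i, P) != 1:
            continue
        k0 %= P
        k1 = inverse_mod(-negk1i, P)
        # accept the candidate only if it encrypts every pair correctly
        if all(x * k0 % P * C % M * k1 % P == y for x, y in pairs):
            break
    else:
        continue
    print("Row %d: key recovered: %x %x" % (irow, k0, k1))
    break
```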
```
Row 11: key recovered: 1df19a439748567 29ad0f3aac513b9
```
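That was fast! Now let's decrypt the flag. Recall that a couple of bits are lost in the last step, so for each block we have to try every $\epsilon \in \{0, \dots, 4\}$ (the loop variable `t` below). The first reduction modulo $P$ destroys a few bits too: $x$ and $x + P$ always encrypt to the same block, so each candidate yields all of its 64-bit lifts, and we keep the ones that decode to printable ASCII.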
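```python
k0i = inverse_mod(k0, P)
k1i = inverse_mod(k1, P)
Ci = inverse_mod(C, M)  # C is odd, hence invertible modulo 2^60

def decrypt_block(y):
    # the value after step 2 is (y * K1^{-1} mod P) + t*P for some 0 <= t <= 4
    for t in range(5):
        v = y * k1i % P
        v += t*P
        if v >= M:
            continue
        v = v * Ci % M
        # the value after step 1 must be a residue modulo P
        if v >= P:
            continue
        v = v * k0i % P
        # every 64-bit x congruent to v modulo P encrypts to y
        while v < 2**64:
            # assert encrypt_block(v) == y
            yield v
            v += P
```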
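Since decryption enumerates candidates, it is worth testing it on data we control. Below is a plain-Python sketch, not the author's code. In it, `pow(a, -1, m)` (Python 3.8+) replaces Sage's `inverse_mod`, and keys are passed as arguments. It encrypts a random printable 8-byte block under random keys and checks that the generator is sound (every candidate re-encrypts to the ciphertext) and complete (the true plaintext is among them). Each valid `t` contributes 74 or 75 lifts, because $74P < 2^{64} < 75P$.

```python
import math
import random
import struct

P = 247359019496198933
C = 223805275076627807
M = 2**60

def encrypt_block(x, k0, k1):
    return x * k0 % P * C % M * k1 % P

def decrypt_block(y, k0, k1):
    """Yield every 64-bit x with encrypt_block(x, k0, k1) == y."""
    k0i, k1i, Ci = pow(k0, -1, P), pow(k1, -1, P), pow(C, -1, M)
    for t in range(5):
        v = y * k1i % P + t * P   # candidate value after step 2
        if v >= M:
            continue
        v = v * Ci % M            # candidate value after step 1
        if v >= P:
            continue
        v = v * k0i % P
        while v < 2**64:          # all lifts of v modulo P
            yield v
            v += P

def random_key():
    while True:
        k = random.randint(1, P - 1)
        if math.gcd(k, P) == 1:
            return k

k0, k1 = random_key(), random_key()
plain = bytes(random.randrange(0x20, 127) for _ in range(8))
x = int.from_bytes(plain, "little")
y = encrypt_block(x, k0, k1)
candidates = list(decrypt_block(y, k0, k1))

assert x in candidates
assert all(encrypt_block(c, k0, k1) == y for c in candidates)
printable = [struct.pack("<Q", c) for c in candidates
             if all(0x20 <= b < 127 for b in struct.pack("<Q", c))]
assert plain in printable
```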
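```python
from struct import unpack, pack

# the encrypted flag is the last token of the file, hex-encoded
ct = open("res").read()[-200:].strip().split()[-1]
ct = bytes.fromhex(ct)
# explicit little-endian, matching pack("<Q", ...) below
ct = unpack('<%dQ' % (len(ct) // 8), ct)

flag = b""
for block in ct:
    for dec in decrypt_block(block):
        dec = pack("<Q", dec)
        # keep only candidates that are printable ASCII
        if all(0x20 <= v < 127 for v in dec):
            print(dec)
            flag += dec
print(b"flag{" + flag + b"}")
```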
```
b'_p4droNe'
b'_a5k3d_m'
b'3_7o_br1'
b'ng_tHis_'
b'foR_thE_'
b'5ign0Ra.'
b'flag{_p4droNe_a5k3d_m3_7o_br1ng_tHis_foR_thE_5ign0Ra.}'
```