Hackvent 2024: Day 16


[HV24.16] Santa's Signatures

Introduction
Because of new bureaucratic regulations, Santa has to sign every package that he sends out. So far, he has always used his drawing pad to sign manually but lately, he has been getting hand cramps and the doctor recommended him to try out digital signatures. Thus, he has tasked one of his elves to implement such a system and has published 4 digital signatures of his favourite lyrics to the world. Unfortunately, you didn't have the time to ask him for more samples...

Analyze the code and get the flag.
Flag format: HV24{}
sha256sum of Santa's Signatures: f9b84f2f41ee4a8a9baa9029cf66a9b17e43fc83d651d63ea75de424dd21ade3
sha256sum of output.txt: a1c122fba8daa6d4d9d47f353815ec0d4225446ec468b4399644741af62c0286

This challenge was written by kuyaya. Very glad that we're in a CTF era where we don't have to understand post-quantum cryptography yet.

Solution

Initial Analysis

We are presented with two files.

The first is santas-signatures.py:

from ecdsa import SigningKey, NIST192p
from hashlib import sha256
import os

from hackvent import flag

private_key = SigningKey.generate(curve=NIST192p)
public_key = private_key.get_verifying_key()

curve = NIST192p
n = curve.order

message = b"""
We're no strangers to love
You know the rules and so do I (do I)
A full commitment's what I'm thinking of
You wouldn't get this from any other guy

I just wanna tell you how I'm feeling
Gotta make you understand

Never gonna give you up
Never gonna run around
Never gonna make you cry
Never gonna tell a lie and hurt you

We've known each other for so long
Your heart's been aching, but you're too shy to say it (say it)
Inside, we both know what's been going on
We know the game, and we're gonna play it

And if you ask me how I'm feeling
Don't tell me you're too blind to see

Never gonna give you up
Never gonna let you down
Never gonna run around and desert you
Never gonna make you cry
Never gonna say goodbye
Never gonna tell a lie and hurt you

Never gonna give, never gonna give
Never gonna give, never gonna give
Never gonna give, never gonna give (give you up)
(Ooh) never gonna give, never gonna give (give you up)

I just wanna tell you how I'm feeling
Gotta make you understand
Never gonna, never gonna

Never gonna give you up
Never gonna let you down
Never gonna run around and desert you
Never gonna make you cry
Never gonna say goodbye
Never gonna tell a lie and hurt you

Never gonna give you up
Never gonna let you down
Never gonna run around and desert you
Never gonna make you cry
Never gonna say goodbye
Never gonna tell a lie and hurt you
"""

r_list = []
s_list = []

h = int.from_bytes(sha256(message).digest(), "big") % n

for _ in range(4):
    k = int.from_bytes(flag + os.urandom(4), "big")
    assert k < n
    r = (k * curve.generator).x() % n
    s = (pow(k, -1, n) * (h + r * private_key.privkey.secret_multiplier)) % n
    r_list.append(int(r))
    s_list.append(int(s))


print("r =", r_list)
print("s =", s_list)

The second is output.txt:

r = [382825619053484650723101111089716481637169498894438388011, 2846338329314931410625679965921020604974471932472870479272, 4539748290341241446856454569550628724992441965649378727404, 941682904620798018129415714406121176743478727872983123639]
s = [1053747182506109288607080233885972025033725041930583121945, 271361922488295908863717359631373504169617539839833749415, 1147747170412930491481269098330085803226817442551773675299, 3831443458083168767818771718543562148023158622090413416724]

Santa is using ECDSA over the NIST192p (P-192) curve to sign his messages. We have four signatures: the same message has been signed four times, each time with a different nonce k.
The nonces are set like this:

k = int.from_bytes(flag + os.urandom(4), "big")

The 4 random bytes at the end make the nonces (almost certainly) distinct, but the nonces are still biased: the flag is static and is used as the prefix of every nonce, so all four nonces share the same most significant bytes. We don't know the flag itself, but we know the flag format is HV24{...}. The nonce must be smaller than the curve order n, which is 192 bits long, so k is at most 192 bits (24 bytes), although it can be shorter. After subtracting the 4 random bytes, at most 20 bytes are left for the flag. This means that the flag's contents should be anywhere from one to fourteen characters in length, and the total flag length (including the HV24{}) should be between seven and twenty characters.
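The 20-byte limit can be checked with a few lines of plain Python. This is a sketch that only assumes the group order n of NIST P-192 (the same constant the solve script hard-codes); the flag below is a hypothetical placeholder of maximal length, not the real flag.

```python
import os

# Order of the NIST P-192 group (same constant as in the solve script)
n = 6277101735386680763835789423176059013767194773182842284081

PREFIX = b"HV24{"
RANDOM_BYTES = 4  # os.urandom(4) in the challenge

# k must satisfy k < n, and n is a 192-bit number, so k has at most 24 bytes.
max_nonce_bytes = n.bit_length() // 8                 # 24
max_flag_bytes = max_nonce_bytes - RANDOM_BYTES        # 20
max_content_bytes = max_flag_bytes - len(PREFIX) - 1   # 14, excluding "HV24{" and "}"

# Hypothetical placeholder flag of maximal length (NOT the real flag).
fake_flag = PREFIX + b"A" * max_content_bytes + b"}"
k = int.from_bytes(fake_flag + os.urandom(RANDOM_BYTES), "big")
assert k < n  # same check as in the challenge code

# With a 20-byte flag, the top 40 bits of every nonce are the bytes "HV24{".
known_bits = 8 * len(PREFIX)
assert k >> (8 * max_nonce_bytes - known_bits) == int.from_bytes(PREFIX, "big")

print(max_flag_bytes, max_content_bytes, known_bits)  # 20 14 40
```

If the real flag were shorter than 20 bytes, the prefix would sit lower in k; the solve script's flag_length = 20 corresponds to the maximal case checked here.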

ECDSA biased-k lattice attack

Since we know the most significant bits of every nonce k (the HV24{ prefix), we suspect these signatures are vulnerable to a biased-k lattice attack. Such an attack exploits predictable bits in the nonce k: each signature becomes a linear relation in the private key with a small unknown, and lattice reduction over several of these relations (an instance of the hidden number problem) recovers the private key behind the signatures.
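To see why a known prefix helps, rearrange the signing equation s = k^-1 (h + r*d) mod n into k = s^-1*h + s^-1*r*d mod n. Writing k = msb * 2^u + x, where msb is the known prefix and x is the unknown u-bit tail, gives x = a*d + b mod n with a = s^-1*r and b = s^-1*h - 2^u*msb, which is the same a and b that dsa_known_msb builds. The sketch below checks this identity numerically. It assumes a full 20-byte flag (so u = 152), generates a private key locally for illustration, and uses random stand-in values for r instead of real curve points, since the identity is purely modular and holds for any r.

```python
import secrets
from hashlib import sha256

# Order of the NIST P-192 group
n = 6277101735386680763835789423176059013767194773182842284081

PREFIX = b"HV24{"
NONCE_BITS = 192                        # assumes a 20-byte flag + 4 random bytes
u = NONCE_BITS - 8 * len(PREFIX)        # number of unknown low bits (152)
msb = int.from_bytes(PREFIX, "big")     # known high bits of every nonce

h = int.from_bytes(sha256(b"example message").digest(), "big") % n
d = secrets.randbelow(n - 1) + 1        # locally generated private key (illustration only)

checks = []
for _ in range(4):
    x_true = secrets.randbelow(2 ** u)  # unknown tail of the nonce
    k = msb * 2 ** u + x_true           # nonce with the known "HV24{" prefix
    r = secrets.randbelow(n - 1) + 1    # stand-in for (k*G).x mod n; identity holds for any r
    s = pow(k, -1, n) * (h + r * d) % n # ECDSA signing equation, as in the challenge

    # Hidden number problem form: x = a*d + b (mod n), with x small (< 2^u)
    a = pow(s, -1, n) * r % n
    b = (pow(s, -1, n) * h - 2 ** u * msb) % n
    checks.append((a * d + b) % n == x_true)

print(all(checks))  # True
```

The lattice attack then searches for the d that makes all four values a*d + b mod n simultaneously small, which is what the attack function encodes in its basis matrix.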

However, the HV24{ prefix only gives us 40 known MSBs per nonce. This is not enough for the attack to succeed, but 48 bits would be sufficient. Our plan is therefore to brute-force one extra byte (the sixth flag character), giving us 48 bits of known prefix. The CTF author hints that we should use the lattice attack from the crypto-attacks repository, so we use that code to perform the attack.
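A rough rule of thumb for this kind of lattice attack (a heuristic, not a precise bound) is that the known nonce bits summed over all signatures should reach the bit length of n. With four signatures, 40 known bits per nonce give 160 bits, well short of 192, while 48 bits give exactly 192, right at the threshold. The sketch below computes this budget and enumerates the 6-byte (48-bit) prefixes to try, assuming the sixth flag character comes from the same charset the solve script uses; total_known_bits is a hypothetical helper for illustration.

```python
import string

N_BITS = 192      # bit length of the NIST P-192 group order
SIGNATURES = 4    # number of signatures in output.txt

def total_known_bits(prefix_bytes: int) -> int:
    """Hypothetical helper: known nonce bits summed over all signatures."""
    return SIGNATURES * 8 * prefix_bytes

# Same candidate set as the solve script
charset = string.ascii_letters + string.digits + string.punctuation

# One 6-byte known prefix per guessed sixth character
candidates = {c: int.from_bytes(b"HV24{" + c.encode(), "big") for c in charset}

print(total_known_bits(5), total_known_bits(6), len(candidates))  # 160 192 94
```

Each candidate prefix would then serve as the known MSB for one run of the lattice step, so the brute force costs at most 94 lattice reductions.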

First, we spawn a Docker container with SageMath, which the script requires and which we do not have installed locally:

$ docker run -it sagemath/sagemath

Next, we write the following solve script. The first part is copied from the crypto-attacks repo, and the final portion is our own:

# Hackvent 2024 - Day 16
# Mo Beigi
#
# Perform biased-k lattice attack on Santa's Signatures.

# Other imports
import string
from hashlib import sha256
import itertools

from math import ceil

# https://github.com/jvdsn/crypto-attacks/blob/master/shared/partial_integer.py
from math import log2
class PartialInteger:
    """
    Represents positive integers with some known and some unknown bits.
    """

    def __init__(self):
        """
        Constructs a new PartialInteger with total bit length 0 and no components.
        """
        self.bit_length = 0
        self.unknowns = 0
        self._components = []

    def add_known(self, value, bit_length):
        """
        Adds a known component to the msb of this PartialInteger.
        :param value: the value of the component
        :param bit_length: the bit length of the component
        :return: this PartialInteger, with the component added to the msb
        """
        self.bit_length += bit_length
        self._components.append((value, bit_length))
        return self

    def add_unknown(self, bit_length):
        """
        Adds an unknown component to the msb of this PartialInteger.
        :param bit_length: the bit length of the component
        :return: this PartialInteger, with the component added to the msb
        """
        self.bit_length += bit_length
        self.unknowns += 1
        self._components.append((None, bit_length))
        return self

    def get_known_lsb(self):
        """
        Returns all known lsb in this PartialInteger.
        This method can cross multiple known components, but stops once an unknown component is encountered.
        :return: a tuple containing the known lsb and the bit length of the known lsb
        """
        lsb = 0
        lsb_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                return lsb, lsb_bit_length

            lsb = lsb + (value << lsb_bit_length)
            lsb_bit_length += bit_length

        return lsb, lsb_bit_length

    def get_known_msb(self):
        """
        Returns all known msb in this PartialInteger.
        This method can cross multiple known components, but stops once an unknown component is encountered.
        :return: a tuple containing the known msb and the bit length of the known msb
        """
        msb = 0
        msb_bit_length = 0
        for value, bit_length in reversed(self._components):
            if value is None:
                return msb, msb_bit_length

            msb = (msb << bit_length) + value
            msb_bit_length += bit_length

        return msb, msb_bit_length

    def get_known_middle(self):
        """
        Returns all known middle bits in this PartialInteger.
        This method can cross multiple known components, but stops once an unknown component is encountered.
        :return: a tuple containing the known middle bits and the bit length of the known middle bits
        """
        middle = 0
        middle_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                if middle_bit_length > 0:
                    return middle, middle_bit_length
            else:
                middle = middle + (value << middle_bit_length)
                middle_bit_length += bit_length

        return middle, middle_bit_length

    def get_unknown_lsb(self):
        """
        Returns the bit length of the unknown lsb in this PartialInteger.
        This method can cross multiple unknown components, but stops once a known component is encountered.
        :return: the bit length of the unknown lsb
        """
        lsb_bit_length = 0
        for value, bit_length in self._components:
            if value is not None:
                return lsb_bit_length

            lsb_bit_length += bit_length

        return lsb_bit_length

    def get_unknown_msb(self):
        """
        Returns the bit length of the unknown msb in this PartialInteger.
        This method can cross multiple unknown components, but stops once a known component is encountered.
        :return: the bit length of the unknown msb
        """
        msb_bit_length = 0
        for value, bit_length in reversed(self._components):
            if value is not None:
                return msb_bit_length

            msb_bit_length += bit_length

        return msb_bit_length

    def get_unknown_middle(self):
        """
        Returns the bit length of the unknown middle bits in this PartialInteger.
        This method can cross multiple unknown components, but stops once a known component is encountered.
        :return: the bit length of the unknown middle bits
        """
        middle_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                if middle_bit_length > 0:
                    return middle_bit_length
            else:
                middle_bit_length += bit_length

        return middle_bit_length

    def matches(self, i):
        """
        Returns whether this PartialInteger matches an integer, that is, all known bits are equal.
        :param i: the integer
        :return: True if this PartialInteger matches i, False otherwise
        """
        shift = 0
        for value, bit_length in self._components:
            if value is not None and (i >> shift) % (2 ** bit_length) != value:
                return False

            shift += bit_length

        return True

    def sub(self, unknowns):
        """
        Substitutes some values for the unknown components in this PartialInteger.
        These values can be symbolic (e.g. Sage variables)
        :param unknowns: the unknowns
        :return: an integer or expression with the unknowns substituted
        """
        assert len(unknowns) == self.unknowns
        i = 0
        j = 0
        shift = 0
        for value, bit_length in self._components:
            if value is None:
                # We don't shift here because the unknown could be a symbolic variable
                i += 2 ** shift * unknowns[j]
                j += 1
            else:
                i += value << shift

            shift += bit_length

        return i

    def get_known_and_unknowns(self):
        """
        Returns i_, o, and l such that this integer i = i_ + sum(2^(o_j) * i_j) with i_j < 2^(l_j).
        :return: a tuple of i_, o, and l
        """
        i_ = 0
        o = []
        l = []
        offset = 0
        for value, bit_length in self._components:
            if value is None:
                o.append(offset)
                l.append(bit_length)
            else:
                i_ += 2 ** offset * value

            offset += bit_length

        return i_, o, l

    def get_unknown_bounds(self):
        """
        Returns a list of bounds on each of the unknowns in this PartialInteger.
        A bound is simply 2^l with l the bit length of the unknown.
        :return: the list of bounds
        """
        return [2 ** bit_length for value, bit_length in self._components if value is None]

    def to_int(self):
        """
        Converts this PartialInteger to an int.
        The number of unknowns must be zero.
        :return: the int represented by this PartialInteger
        """
        assert self.unknowns == 0
        return self.sub([])

    def to_string_le(self, base, symbols="0123456789abcdefghijklmnopqrstuvwxyz"):
        """
        Converts this PartialInteger to a list of characters in the provided base (little endian).
        :param base: the base, must be a power of two and less than or equal to 36
        :param symbols: the symbols to use, at least as many as base (default: "0123456789abcdefghijklmnopqrstuvwxyz")
        :return: the list of characters, with '?' representing an unknown digit
        """
        assert (base & (base - 1)) == 0, "Base must be power of two."
        assert base <= 36
        assert len(symbols) >= base
        bits_per_element = int(log2(base))
        chars = []
        for value, bit_length in self._components:
            assert bit_length % bits_per_element == 0, f"Component with bit length {bit_length} can't be represented by base {base} digits"
            for _ in range(bit_length // bits_per_element):
                if value is None:
                    chars.append('?')
                else:
                    chars.append(symbols[value % base])
                    value //= base

        return chars

    def to_string_be(self, base, symbols="0123456789abcdefghijklmnopqrstuvwxyz"):
        """
        Converts this PartialInteger to a list of characters in the provided base (big endian).
        :param base: the base, must be a power of two and less than or equal to 36
        :param symbols: the symbols to use, at least as many as base (default: "0123456789abcdefghijklmnopqrstuvwxyz")
        :return: the list of characters, with '?' representing an unknown digit
        """
        return self.to_string_le(base, symbols)[::-1]

    def to_bits_le(self, symbols="01"):
        """
        Converts this PartialInteger to a list of bit characters (little endian).
        :param symbols: the two symbols to use (default: "01")
        :return: the list of bit characters, with '?' representing an unknown bit
        """
        assert len(symbols) == 2
        return self.to_string_le(2, symbols)

    def to_bits_be(self, symbols="01"):
        """
        Converts this PartialInteger to a list of bit characters (big endian).
        :param symbols: the two symbols to use (default: "01")
        :return: the list of bit characters, with '?' representing an unknown bit
        """
        return self.to_bits_le(symbols)[::-1]

    def to_hex_le(self, symbols="0123456789abcdef"):
        """
        Converts this PartialInteger to a list of hex characters (little endian).
        :param symbols: the 16 symbols to use (default: "0123456789abcdef")
        :return: the list of hex characters, with '?' representing an unknown nibble
        """
        assert len(symbols) == 16
        return self.to_string_le(16, symbols)

    def to_hex_be(self, symbols="0123456789abcdef"):
        """
        Converts this PartialInteger to a list of hex characters (big endian).
        :param symbols: the 16 symbols to use (default: "0123456789abcdef")
        :return: the list of hex characters, with '?' representing an unknown nibble
        """
        return self.to_hex_le(symbols)[::-1]

    @staticmethod
    def unknown(bit_length):
        return PartialInteger().add_unknown(bit_length)

    @staticmethod
    def parse_le(digits, base):
        """
        Constructs a PartialInteger from arbitrary digits in a provided base (little endian).
        :param digits: the digits (string with '?' representing unknown or list with '?'/None representing unknown)
        :param base: the base, must be a power of two and less than or equal to 36
        :return: a PartialInteger with known and unknown components as indicated by the digits
        """
        assert (base & (base - 1)) == 0, "Base must be power of two."
        assert base <= 36
        bits_per_element = int(log2(base))
        p = PartialInteger()
        rc_k = 0
        rc_u = 0
        value = 0
        for digit in digits:
            if digit is None or digit == '?':
                if rc_k > 0:
                    p.add_known(value, rc_k * bits_per_element)
                    rc_k = 0
                    value = 0
                rc_u += 1
            else:
                if isinstance(digit, str):
                    digit = int(digit, base)
                assert 0 <= digit < base
                if rc_u > 0:
                    p.add_unknown(rc_u * bits_per_element)
                    rc_u = 0
                value += digit * base ** rc_k
                rc_k += 1

        if rc_k > 0:
            p.add_known(value, rc_k * bits_per_element)

        if rc_u > 0:
            p.add_unknown(rc_u * bits_per_element)

        return p

    @staticmethod
    def parse_be(digits, base):
        """
        Constructs a PartialInteger from arbitrary digits in a provided base (big endian).
        :param digits: the digits (string with '?' representing unknown or list with '?'/None representing unknown)
        :param base: the base (must be a power of two and less than or equal to 36)
        :return: a PartialInteger with known and unknown components as indicated by the digits
        """
        return PartialInteger.parse_le(reversed(digits), base)

    @staticmethod
    def from_bits_le(bits):
        """
        Constructs a PartialInteger from bits (little endian).
        :param bits: the bits (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the bits
        """
        return PartialInteger.parse_le(bits, 2)

    @staticmethod
    def from_bits_be(bits):
        """
        Constructs a PartialInteger from bits (big endian).
        :param bits: the bits (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the bits
        """
        return PartialInteger.from_bits_le(reversed(bits))

    @staticmethod
    def from_hex_le(hex):
        """
        Constructs a PartialInteger from hex characters (little endian).
        :param hex: the hex characters (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the hex characters
        """
        return PartialInteger.parse_le(hex, 16)

    @staticmethod
    def from_hex_be(hex):
        """
        Constructs a PartialInteger from hex characters (big endian).
        :param hex: the hex characters (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the hex characters
        """
        return PartialInteger.from_hex_le(reversed(hex))

    @staticmethod
    def from_lsb(bit_length, lsb, lsb_bit_length):
        """
        Constructs a PartialInteger from some known lsb, setting the msb to unknown.
        :param bit_length: the total bit length of the integer
        :param lsb: the known lsb
        :param lsb_bit_length: the bit length of the known lsb
        :return: a PartialInteger with one known component (the lsb) and one unknown component (the msb)
        """
        assert bit_length >= lsb_bit_length
        assert 0 <= lsb < (2 ** lsb_bit_length)
        return PartialInteger().add_known(lsb, lsb_bit_length).add_unknown(bit_length - lsb_bit_length)

    @staticmethod
    def from_msb(bit_length, msb, msb_bit_length):
        """
        Constructs a PartialInteger from some known msb, setting the lsb to unknown.
        :param bit_length: the total bit length of the integer
        :param msb: the known msb
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with one known component (the msb) and one unknown component (the lsb)
        """
        assert bit_length >= msb_bit_length
        assert 0 <= msb < (2 ** msb_bit_length)
        return PartialInteger().add_unknown(bit_length - msb_bit_length).add_known(msb, msb_bit_length)

    @staticmethod
    def from_lsb_and_msb(bit_length, lsb, lsb_bit_length, msb, msb_bit_length):
        """
        Constructs a PartialInteger from some known lsb and msb, setting the middle bits to unknown.
        :param bit_length: the total bit length of the integer
        :param lsb: the known lsb
        :param lsb_bit_length: the bit length of the known lsb
        :param msb: the known msb
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with two known components (the lsb and msb) and one unknown component (the middle bits)
        """
        assert bit_length >= lsb_bit_length + msb_bit_length
        assert 0 <= lsb < (2 ** lsb_bit_length)
        assert 0 <= msb < (2 ** msb_bit_length)
        middle_bit_length = bit_length - lsb_bit_length - msb_bit_length
        return PartialInteger().add_known(lsb, lsb_bit_length).add_unknown(middle_bit_length).add_known(msb, msb_bit_length)

    @staticmethod
    def from_middle(middle, middle_bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from some known middle bits, setting the lsb and msb to unknown.
        :param middle: the known middle bits
        :param middle_bit_length: the bit length of the known middle bits
        :param lsb_bit_length: the bit length of the unknown lsb
        :param msb_bit_length: the bit length of the unknown msb
        :return: a PartialInteger with one known component (the middle bits) and two unknown components (the lsb and msb)
        """
        assert 0 <= middle < (2 ** middle_bit_length)
        return PartialInteger().add_unknown(lsb_bit_length).add_known(middle, middle_bit_length).add_unknown(msb_bit_length)

    @staticmethod
    def lsb_of(i, bit_length, lsb_bit_length):
        """
        Constructs a PartialInteger from the lsb of a known integer, setting the msb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the known lsb
        :return: a PartialInteger with one known component (the lsb) and one unknown component (the msb)
        """
        lsb = i % (2 ** lsb_bit_length)
        return PartialInteger.from_lsb(bit_length, lsb, lsb_bit_length)

    @staticmethod
    def msb_of(i, bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the msb of a known integer, setting the lsb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with one known component (the msb) and one unknown component (the lsb)
        """
        msb = i >> (bit_length - msb_bit_length)
        return PartialInteger.from_msb(bit_length, msb, msb_bit_length)

    @staticmethod
    def lsb_and_msb_of(i, bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the lsb and msb of a known integer, setting the middle bits to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the known lsb
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with two known components (the lsb and msb) and one unknown component (the middle bits)
        """
        lsb = i % (2 ** lsb_bit_length)
        msb = i >> (bit_length - msb_bit_length)
        return PartialInteger.from_lsb_and_msb(bit_length, lsb, lsb_bit_length, msb, msb_bit_length)

    @staticmethod
    def middle_of(i, bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the middle bits of a known integer, setting the lsb and msb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the unknown lsb
        :param msb_bit_length: the bit length of the unknown msb
        :return: a PartialInteger with one known component (the middle bits) and two unknown components (the lsb and msb)
        """
        middle_bit_length = bit_length - lsb_bit_length - msb_bit_length
        middle = (i >> lsb_bit_length) % (2 ** middle_bit_length)
        return PartialInteger.from_middle(middle, middle_bit_length, lsb_bit_length, msb_bit_length)


# https://github.com/jvdsn/crypto-attacks/blob/master/attacks/hnp/lattice_attack.py
import os
import sys

from sage.all import QQ
from sage.all import ZZ
from sage.all import matrix
from sage.all import vector

def shortest_vectors(B):
    """
    Computes the shortest non-zero vectors in a lattice.
    :param B: the basis of the lattice
    :return: a generator generating the shortest non-zero vectors
    """
    #logging.debug(f"Computing shortest vectors in {B.nrows()} x {B.ncols()} matrix...")
    B = B.LLL()

    for row in B.rows():
        if not row.is_zero():
            yield row

def attack(a, b, m, X):
    """
    Solves the hidden number problem using an attack based on the shortest vector problem.
    The hidden number problem is defined as finding y such that {xi = {aij * yj} + bi mod m}.
    :param a: the aij values
    :param b: the bi values
    :param m: the modulus
    :param X: a bound on the xi values
    :return: a generator generating tuples containing a list of xi values and a list of yj values
    """
    assert len(a) == len(b), "a and b lists should be of equal length."

    n1 = len(a)
    n2 = len(a[0])
    B = matrix(QQ, n1 + n2 + 1, n1 + n2 + 1)
    for i in range(n1):
        for j in range(n2):
            B[n1 + j, i] = a[i][j]

        B[i, i] = m
        B[n1 + n2, i] = b[i] - X // 2

    for j in range(n2):
        B[n1 + j, n1 + j] = X / QQ(m)

    B[n1 + n2, n1 + n2] = X

    for v in shortest_vectors(B):
        xs = [int(v[i] + X // 2) for i in range(n1)]
        ys = [(int(v[n1 + j] * m) // X) % m for j in range(n2)]
        if all(y != 0 for y in ys) and v[n1 + n2] == X:
            yield xs, ys


def dsa_known_msb(n, h, r, s, k):
    """
    Recovers the (EC)DSA private key and nonces if the most significant nonce bits are known.
    :param n: the modulus
    :param h: a list containing the hashed messages
    :param r: a list containing the r values
    :param s: a list containing the s values
    :param k: a list containing the partial nonces (PartialIntegers)
    :return: a generator generating tuples containing the possible private key and a list of nonces
    """
    assert len(h) == len(r) == len(s) == len(k), "h, r, s, and k lists should be of equal length."
    a = []
    b = []
    X = 0
    for hi, ri, si, ki in zip(h, r, s, k):
        msb, msb_bit_length = ki.get_known_msb()
        shift = 2 ** ki.get_unknown_lsb()
        a.append([(pow(si, -1, n) * ri) % n])
        b.append((pow(si, -1, n) * hi - shift * msb) % n)
        X = max(X, shift)

    for k_, x in attack(a, b, n, X):
        yield x[0], [ki.sub([ki_]) for ki, ki_ in zip(k, k_)]


def dsa_known_lsb(n, h, r, s, k):
    """
    Recovers the (EC)DSA private key and nonces if the least significant nonce bits are known.
    :param n: the modulus
    :param h: a list containing the hashed messages
    :param r: a list containing the r values
    :param s: a list containing the s values
    :param k: a list containing the partial nonces (PartialIntegers)
    :return: a generator generating tuples containing the possible private key and a list of nonces
    """
    assert len(h) == len(r) == len(s) == len(k), "h, r, s, and k lists should be of equal length."
    a = []
    b = []
    X = 0
    for hi, ri, si, ki in zip(h, r, s, k):
        lsb, lsb_bit_length = ki.get_known_lsb()
        inv_shift = pow(2 ** lsb_bit_length, -1, n)
        a.append([(inv_shift * pow(si, -1, n) * ri) % n])
        b.append((inv_shift * pow(si, -1, n) * hi - inv_shift * lsb) % n)
        X = max(X, 2 ** ki.get_unknown_msb())

    for k_, x in attack(a, b, n, X):
        nonces = [ki.sub([ki_]) for ki, ki_ in zip(k, k_)]
        yield x[0], nonces


def dsa_known_middle(n, h1, r1, s1, k1, h2, r2, s2, k2):
    """
    Recovers the (EC)DSA private key and nonces if the middle nonce bits are known.
    This is a heuristic extension which might perform worse than the methods to solve the Extended Hidden Number Problem.
    More information: De Micheli G., Heninger N., "Recovering cryptographic keys from partial information, by example" (Section 5.2.3)
    :param n: the modulus
    :param h1: the first hashed message
    :param r1: the first r value
    :param s1: the first s value
    :param k1: the first partial nonce (PartialInteger)
    :param h2: the second hashed message
    :param r2: the second r value
    :param s2: the second s value
    :param k2: the second partial nonce (PartialInteger)
    :return: a tuple containing the private key, the nonce of the first signature, and the nonce of the second signature
    """
    k_bit_length = k1.bit_length
    assert k_bit_length == k2.bit_length
    lsb_unknown = k1.get_unknown_lsb()
    assert lsb_unknown == k2.get_unknown_lsb()
    msb_unknown = k1.get_unknown_msb()
    assert msb_unknown == k2.get_unknown_msb()
    K = 2 ** max(lsb_unknown, msb_unknown)
    l = k_bit_length - msb_unknown

    a1 = k1.get_known_middle()[0] << lsb_unknown
    a2 = k2.get_known_middle()[0] << lsb_unknown
    t = -(pow(s1, -1, n) * s2 * r1 * pow(r2, -1, n))
    u = pow(s1, -1, n) * r1 * h2 * pow(r2, -1, n) - pow(s1, -1, n) * h1
    u_ = a1 + t * a2 + u

    B = matrix(ZZ, 5, 5)
    B[0] = vector(ZZ, [K, K * 2 ** l, K * t, K * t * 2 ** l, u_])
    B[1] = vector(ZZ, [0, K * n, 0, 0, 0])
    B[2] = vector(ZZ, [0, 0, K * n, 0, 0])
    B[3] = vector(ZZ, [0, 0, 0, K * n, 0])
    B[4] = vector(ZZ, [0, 0, 0, 0, n])

    A = matrix(ZZ, 4, 4)
    b = []
    for row, v in enumerate(shortest_vectors(B)):
        A[row] = v[:4].apply_map(lambda x: x // K)
        b.append(-v[4])
        if row == A.nrows() - 1:
            break

    assert len(b) == 4
    x1, y1, x2, y2 = A.solve_right(vector(ZZ, b))
    assert (x1 + 2 ** l * y1 + t * x2 + 2 ** l * t * y2 + u_) % n == 0

    k1 = k1.sub([int(x1), int(y1)])
    k2 = k2.sub([int(x2), int(y2)])
    private_key1 = (pow(r1, -1, n) * (s1 * k1 - h1)) % n
    private_key2 = (pow(r2, -1, n) * (s2 * k2 - h2)) % n
    assert private_key1 == private_key2
    return int(private_key1), int(k1), int(k2)

# Santa's Signature parameters
message = b"""
We're no strangers to love
You know the rules and so do I (do I)
A full commitment's what I'm thinking of
You wouldn't get this from any other guy

I just wanna tell you how I'm feeling
Gotta make you understand

Never gonna give you up
Never gonna run around
Never gonna make you cry
Never gonna tell a lie and hurt you

We've known each other for so long
Your heart's been aching, but you're too shy to say it (say it)
Inside, we both know what's been going on
We know the game, and we're gonna play it

And if you ask me how I'm feeling
Don't tell me you're too blind to see

Never gonna give you up
Never gonna let you down
Never gonna run around and desert you
Never gonna make you cry
Never gonna say goodbye
Never gonna tell a lie and hurt you

Never gonna give, never gonna give
Never gonna give, never gonna give
Never gonna give, never gonna give (give you up)
(Ooh) never gonna give, never gonna give (give you up)

I just wanna tell you how I'm feeling
Gotta make you understand
Never gonna, never gonna

Never gonna give you up
Never gonna let you down
Never gonna run around and desert you
Never gonna make you cry
Never gonna say goodbye
Never gonna tell a lie and hurt you

Never gonna give you up
Never gonna let you down
Never gonna run around and desert you
Never gonna make you cry
Never gonna say goodbye
Never gonna tell a lie and hurt you
"""

flag_prefix = "HV24{"
flag_length = 20

# NIST192p.order
n = 6277101735386680763835789423176059013767194773182842284081

h_val_from_message = int.from_bytes(sha256(message).digest(), "big") % n
h_val = 5451444470609933768673875739190099258978652043860043513059
assert h_val == h_val_from_message
h = [h_val] * 4

# From output.txt
r = [382825619053484650723101111089716481637169498894438388011, 2846338329314931410625679965921020604974471932472870479272, 4539748290341241446856454569550628724992441965649378727404, 941682904620798018129415714406121176743478727872983123639]
s = [1053747182506109288607080233885972025033725041930583121945, 271361922488295908863717359631373504169617539839833749415, 1147747170412930491481269098330085803226817442551773675299, 3831443458083168767818771718543562148023158622090413416724]

charset = string.ascii_letters + string.digits + string.punctuation

for char in charset:
    found = False
    
    # Construct k_candidate
    k_candidate = flag_prefix + char
    print(k_candidate)
    
    # Partial integer with 48 known bits / 144 unknown bits
    k_pi = PartialInteger()
    k_pi.add_unknown(144)
    
    k_value = int.from_bytes(k_candidate.encode(), "big")
    k_pi.add_known(k_value, 48)
    
    k = [k_pi] * 4  # same known prefix for all four nonces
    
    # Info
    known_msb, msb_bit_length = k_pi.get_known_msb()
    known_lsb, lsb_bit_length = k_pi.get_known_lsb()
    unknown_msb = k_pi.get_unknown_msb()
    unknown_lsb = k_pi.get_unknown_lsb()

    print(f"Known MSB: {bin(known_msb)}, Known MSB bit length: {msb_bit_length}")
    print(f"Known LSB: {bin(known_lsb)}, Known LSB bit length: {lsb_bit_length}")
    print(f"Unknown MSB bit length: {unknown_msb}")
    print(f"Unknown LSB bit length: {unknown_lsb}")
    print(f"Total bit length: {k_pi.bit_length}")
    
    # Attack
    attempt = dsa_known_msb(n, h, r, s, k)
    for (private_key, k_list) in attempt:
        
        print(f"private_key={private_key}")
        
        # Validate answer
        for k_nonce in k_list:
            # Byte length of the nonce; int.bit_length() avoids the float log2
            total_bytes = (int(k_nonce).bit_length() + 7) // 8
            k_bytes = int(k_nonce).to_bytes(total_bytes, "big")

            print(f"k={k_nonce}")

            flag_candidate = k_bytes[:flag_length]
            
            # Ensure all of flags byte are in expected charset
            if all(chr(byte) in charset for byte in flag_candidate):
                print(f"Recovered flag: {flag_candidate}")
                exit()

print("No result found.")

This script takes the known parameters n, h, r and s from the original santas-signatures.py and output.txt. We then make an educated guess that the flag's contents are drawn from the string.ascii_letters + string.digits + string.punctuation charset, and iterate over this charset to generate the extra byte that follows the known prefix HV24{. Our guesses are therefore HV24{a, HV24{b and so on. We configure the PartialInteger with 48 known MSB bits (the six guessed bytes) and 144 unknown LSB bits, for a total of 192 bits, which is the bit length of the NIST P-192 order n.

For each attempt that yields a private_key and k_list, we validate the result by checking each k value and ensuring that its first 20 bytes all fall within the charset we expect the flag's contents to use. If they do, we have very likely found the original flag.
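The per-guess bookkeeping can be sketched in plain Python, outside Sage. This is a minimal sketch under the same assumptions as the script (a 192-bit nonce whose most significant bytes hold the flag, and the charset guess above); the helpers known_msb, nonce_interval and looks_like_flag are hypothetical names and not part of PartialInteger. It shows that each guess pins the nonce to an interval of 2^144 values and how the charset filter is applied.

```python
import string

# Assumptions mirrored from the solve script: the flag sits in the most
# significant bytes of a 192-bit nonce and uses this charset.
CHARSET = string.ascii_letters + string.digits + string.punctuation
NONCE_BITS = 192          # bit length of the NIST P-192 group order n
FLAG_PREFIX = b"HV24{"
FLAG_LENGTH = 20


def known_msb(guess: str) -> tuple[int, int]:
    """Value and bit count of the known prefix HV24{ + guess (hypothetical helper)."""
    prefix = FLAG_PREFIX + guess.encode()
    return int.from_bytes(prefix, "big"), 8 * len(prefix)


def nonce_interval(guess: str) -> tuple[int, int]:
    """Inclusive range of 192-bit nonces whose top bits equal the guessed prefix."""
    value, known_bits = known_msb(guess)
    unknown_bits = NONCE_BITS - known_bits  # 192 - 48 = 144 for one guessed char
    low = value << unknown_bits
    return low, low + (1 << unknown_bits) - 1


def looks_like_flag(k: int) -> bool:
    """Charset filter on the first FLAG_LENGTH bytes of a 192-bit nonce."""
    k_bytes = k.to_bytes(NONCE_BITS // 8, "big")
    return all(chr(b) in CHARSET for b in k_bytes[:FLAG_LENGTH])


if __name__ == "__main__":
    # First nonce printed for the HV24{j attempt in the output below
    k1 = 1773690810533281069433785461577691874864508299683322378031
    low, high = nonce_interval("j")
    print(low <= k1 <= high)  # True
```

The first recovered nonce from the HV24{j attempt falls inside that guess's interval, which is consistent with the 48-bit Known MSB value printed for that attempt.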

Running this script produces many candidate results, but only the attempt with the HV24{j prefix passes our charset check:

HV24{j
<__main__.PartialInteger object at 0x7f0b098feff0>
<__main__.PartialInteger object at 0x7f0b098feff0>
Known MSB: 0b10010000101011000110010001101000111101101101010, Known MSB bit length: 48
Known LSB: 0b0, Known LSB bit length: 0
Unknown MSB bit length: 0
Unknown LSB bit length: 144
Total bit length: 192
private_key=2590803528145236545355230583815424784952960266209950287878
k=1773690810533281069433785461577691874864508299683322378031
Recovered flag: b'HV24{just_us3_EdDSA}'
k=1773690810533281069433785461577691874864508299686215035367
Recovered flag: b'HV24{just_us3_EdDSA}'
k=1773690810533281069433785461577691874864508299683488554507
Recovered flag: b'HV24{just_us3_EdDSA}'
k=1773690810533281069433785461577691874864508299683150737252
Recovered flag: b'HV24{just_us3_EdDSA}'

We now know that the private key is 2590803528145236545355230583815424784952960266209950287878 and that the flag embedded in the nonces k is HV24{just_us3_EdDSA}.
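Both the final key-recovery step and the t and u coefficients in dsa_known_middle are rearrangements of the ECDSA signing equation s = k^-1 * (h + r*d) mod n. Below is a minimal sketch of that algebra, assuming the same hash-mod-n convention as the script. The function names are hypothetical, and the toy signer takes r as an input rather than computing it from k*G, because the modular algebra does not depend on how r was produced.

```python
# Hypothetical helpers illustrating ECDSA signing-equation algebra:
#     s = k^-1 * (h + r*d)  (mod n)
N = 6277101735386680763835789423176059013767194773182842284081  # NIST P-192 order


def private_key_from_nonce(n: int, h: int, r: int, s: int, k: int) -> int:
    """Solve s*k = h + r*d (mod n) for d, as the script does for private_key1."""
    return (pow(r, -1, n) * (s * k - h)) % n


def hnp_coefficients(n, h1, r1, s1, h2, r2, s2):
    """t, u with k1 + t*k2 + u = 0 (mod n): d eliminated between two signatures."""
    s1_inv = pow(s1, -1, n)
    r2_inv = pow(r2, -1, n)
    t = (-s1_inv * s2 * r1 * r2_inv) % n
    u = (s1_inv * r1 * h2 * r2_inv - s1_inv * h1) % n
    return t, u


def sign_with_given_r(n: int, h: int, r: int, d: int, k: int) -> int:
    """Toy signer: real ECDSA takes r from the x-coordinate of k*G; here r is an input."""
    return (pow(k, -1, n) * (h + r * d)) % n


if __name__ == "__main__":
    # Made-up key, hash, nonces and r values, only to exercise the algebra
    d, h = 123456789, 987654321
    k1, r1 = 2**150 + 17, 111111111
    k2, r2 = 2**151 + 99, 222222222
    s1 = sign_with_given_r(N, h, r1, d, k1)
    s2 = sign_with_given_r(N, h, r2, d, k2)
    print(private_key_from_nonce(N, h, r1, s1, k1) == d)  # True
    t, u = hnp_coefficients(N, h, r1, s1, h, r2, s2)
    print((k1 + t * k2 + u) % N == 0)  # True
```

Because d is unique modulo n, recomputing it from every recovered nonce and comparing the results (as dsa_known_middle does with its private_key1 == private_key2 assertion) is a cheap sanity check on a lattice result.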

What if k were shorter?

In our case the flag was exactly 20 bytes long and the nonce k was 192 bits wide. However, the flag could have been shorter, anywhere between 7 bytes (HV24{ plus one character plus }) and 20 bytes. Because the flag occupies the most significant bytes of k, a shorter flag would leave additional null (\x00) bytes at the start of the nonce. Therefore, if we didn't find a solution for a 20-character flag, we would modify our script to test for a 19-character flag, using a k_candidate such as b'\x00' + b'HV24{'. We would no longer need to brute-force the extra character after HV24{, because the new null byte already provides the 48 bits of known MSB needed. If this still didn't yield a solution, we would test for an 18-character flag with a k_candidate that includes two null bytes, like b'\x00\x00' + b'HV24{', and continue reducing the flag length by one character until we reach a 7-character flag.
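The fallback search for shorter flags can be sketched as a generator of known-prefix candidates. This is a minimal sketch assuming, as in this challenge, that a full-length flag is 20 bytes and sits at the top of the 192-bit nonce; candidate_prefixes and search_order are hypothetical helpers, not part of the original script.

```python
import string

CHARSET = string.ascii_letters + string.digits + string.punctuation
FLAG_PREFIX = b"HV24{"
MAX_FLAG_LENGTH = 20  # full-length flag in this challenge
MIN_FLAG_LENGTH = 7   # "HV24{" + one character + "}"


def candidate_prefixes(flag_length: int) -> list[bytes]:
    """Known most significant bytes to try for one flag length (hypothetical helper)."""
    padding = b"\x00" * (MAX_FLAG_LENGTH - flag_length)
    if padding:
        # Leading null bytes already give at least 48 known bits
        return [padding + FLAG_PREFIX]
    # Full-length flag: brute-force the character after "HV24{"
    return [FLAG_PREFIX + c.encode() for c in CHARSET]


def search_order():
    """Yield (flag_length, prefix, known_bits), longest flag first."""
    for length in range(MAX_FLAG_LENGTH, MIN_FLAG_LENGTH - 1, -1):
        for prefix in candidate_prefixes(length):
            yield length, prefix, 8 * len(prefix)


if __name__ == "__main__":
    for length, prefix, bits in search_order():
        if length < MAX_FLAG_LENGTH:
            print(length, prefix, bits)
```

Each prefix would then be loaded into a PartialInteger in the same way as in the script, with add_unknown(192 - known_bits) followed by add_known(int.from_bytes(prefix, "big"), known_bits). Passing the bit count explicitly is what keeps the leading null bytes inside the known part.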


Flag:

HV24{just_us3_EdDSA}
