Hat puzzle

How to solve the Hat puzzle programmatically.


A. Coady


January 11, 2020

Ten-Hat Variant

In this variant there are 10 prisoners and 10 hats. Each prisoner is assigned a random hat, either red or blue, but the number of hats of each color is not known to the prisoners. The prisoners will be lined up single file, where each can see the hats in front of him but not behind. Starting with the prisoner at the back of the line and moving forward, they must each, in turn, say only one word, which must be “red” or “blue”. If the word matches their hat color they are released; if not, they are killed on the spot. A friendly guard warns them of this test one hour beforehand and tells them that they can formulate a plan whereby, following the stated rules, 9 of the 10 prisoners will definitely survive, and 1 has a 50/50 chance of survival. What is the plan to achieve the goal?

This puzzle involves three concepts common to classic logic puzzles: theory of mind, functional fixedness, and induction.

Theory of mind comes into play because each prisoner has different knowledge, but can assume everyone else will reason the same way. Functional fixedness occurs more subtly: each prisoner may state a color only to convey information, but because the information is encoded as a color, it tends to focus thinking on the colors themselves. To combat that cognitive bias, first create a different enumeration to represent statements. Any binary enum can be mapped back to colors, so why not use bool?

colors = 'red', 'blue'
colors[False], colors[True]
('red', 'blue')
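Because `False` and `True` index as 0 and 1, the same tuple also works in the other direction. The following is a minimal sketch of the round trip between abstract statements and spoken words. The helper names `to_word` and `to_statement` are hypothetical and do not come from the original post.

```python
colors = 'red', 'blue'


def to_word(statement: bool) -> str:
    """Map an abstract statement to the color word actually spoken (hypothetical helper)."""
    return colors[statement]  # bool is a subclass of int, so it indexes the tuple


def to_statement(word: str) -> bool:
    """Map a spoken color word back to the abstract statement (hypothetical helper)."""
    return bool(colors.index(word))


# The round trip is lossless in both directions.
assert all(to_statement(to_word(s)) == s for s in (False, True))
assert all(to_word(to_statement(w)) == w for w in colors)
```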

Which leaves induction: solve the puzzle for the base case (smallest size) first, and then methodically build on that solution. In the case of 1 prisoner, they have no information a priori, and therefore have a 50/50 chance of survival regardless of strategy. This variant of the puzzle already gives the optimal goal, so we know that everyone but the 1st can say their color and be saved, while the 1st can devote their answer to the common cause.

In the case of 2 prisoners, obviously the 1st can say the color of the 2nd. That approach does not scale; it is the path to functional fixedness. Instead, methodically enumerate all possible statements and colors to determine if there is an unambiguous solution.

table = list(zip([False, True], colors))
[(False, 'red'), (True, 'blue')]

The above table is a general solution with no assumptions other than the arbitrary ordering of enums. While it may appear absurdly pedantic, it represents a rule set which is key to building a recursive solution.
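To make the rule set concrete, here is a sketch of how each of the 2 prisoners could consult the table. The 1st prisoner looks up the row matching the color they see, and the 2nd looks up the row matching the statement they hear. The names `speak` and `deduce` are hypothetical. The single-element unpacking raises an error if a lookup is ambiguous, which is exactly the property the table must guarantee.

```python
colors = 'red', 'blue'
table = list(zip([False, True], colors))  # [(False, 'red'), (True, 'blue')]


def speak(table, seen_color):
    """1st prisoner: the statement whose row matches the color they see."""
    (statement,) = [row[0] for row in table if row[1] == seen_color]
    return statement


def deduce(table, statement):
    """2nd prisoner: the color whose row matches the statement they hear."""
    (color,) = [row[1] for row in table if row[0] == statement]
    return color


for hat in colors:
    # Whatever the 2nd prisoner wears, they recover it from the statement.
    assert deduce(table, speak(table, hat)) == hat
```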

With 3 prisoners, the 1st clearly cannot just repeat the above rule set, because the 3rd would then receive no information. But there are only 2 choices, so the only option is to either keep the rule set or follow the opposite one, depending on the 3rd prisoner's color.

The crucial step is to build on the existing table.

# 3rd hat red: keep the rule set. 3rd hat blue: invert the 1st statement.
table = [row + colors[:1] for row in table] + [(not row[0],) + row[1:] + colors[1:] for row in table]
[(False, 'red', 'red'),
 (True, 'blue', 'red'),
 (True, 'red', 'blue'),
 (False, 'blue', 'blue')]

The solution is valid if each prisoner is able to narrow the possibilities to a unique row based on the colors they hear and see.

import collections

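def test(table):
    """Assert that the input table is a valid solution."""
    (size,) = set(map(len, table))  # all rows must have the same length
    for index in range(size):
        # Hide column `index`; the remaining columns must still identify exactly
        # one row, so the hidden value is determined by what that prisoner knows.
        counts = collections.Counter(row[:index] + row[index + 1:] for row in table)
        assert set(counts.values()) == {1}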


The general solution is simply the table-extension step above in recursive form, with a parametrized size.

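def solve(count: int):
    """Generate a flat table of all spoken possibilities."""
    if count <= 1:
        yield False,  # base case: the 1st statement alone
        return  # stop here, otherwise the recursion never terminates
    for row in solve(count - 1):
        yield row + colors[:1]  # new prisoner wears red: keep the rule set
        yield (not row[0],) + row[1:] + colors[1:]  # blue: invert the 1st statement

list(solve(3))
[(False, 'red', 'red'),
 (True, 'red', 'blue'),
 (True, 'blue', 'red'),
 (False, 'blue', 'blue')]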
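The recursion can equally be unrolled into a loop that applies the same doubling step `count - 1` times. This is a minimal sketch; `solve_iterative` is a hypothetical name, and it should produce the same rows in the same order as `solve`.

```python
colors = 'red', 'blue'


def solve_iterative(count: int):
    """Build the flat table bottom-up instead of recursively."""
    table = [(False,)]  # base case: 1 prisoner
    for _ in range(count - 1):
        # Each existing row spawns two: the new prisoner wears red (rule set kept)
        # or blue (1st statement inverted), interleaved like the generator.
        table = [new
                 for row in table
                 for new in (row + colors[:1],
                             (not row[0],) + row[1:] + colors[1:])]
    return table


print(solve_iterative(3))
```

Each pass doubles the table, so `count` prisoners yield `2 ** (count - 1)` rows, one per arrangement of the hats the 1st prisoner can see.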

The seemingly complicated puzzle is actually a trivial recurrence relation: \[ 2^n = 2 \cdot 2^{n-1} \] There are \(2^n\) states of the prisoners, and each prisoner has \(n-1\) bits of data, so one additional bit of data from the 1st is sufficient to solve the puzzle.

table = list(solve(10))
test(table)  # passes: every prisoner can narrow the possibilities to a unique row
table[:3]
[(False, 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'red'),
 (True, 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'blue'),
 (True, 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'blue', 'red')]
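To see the table in action, here is a sketch of one round of the game. The 1st prisoner finds the row matching the 9 hats they see. Every later prisoner finds the unique row consistent with three things: the statement they heard, the colors announced behind them, and the hats they see in front. The `play` helper and the random hat assignment are illustrative assumptions, not part of the original post. The 1st prisoner's own hat is left out because the table never depends on it.

```python
import random

colors = 'red', 'blue'


def solve(count: int):
    """Generate a flat table of all spoken possibilities."""
    if count <= 1:
        yield False,
        return
    for row in solve(count - 1):
        yield row + colors[:1]
        yield (not row[0],) + row[1:] + colors[1:]


def play(table, hats):
    """Simulate one round; return the colors spoken by prisoners 2..n.

    `hats` lists the hats of prisoners 2..n from back to front.
    """
    # 1st prisoner: the statement whose row matches every hat they see.
    (statement,) = [r[0] for r in table if r[1:] == hats]
    spoken = []
    for i in range(len(hats)):
        heard = tuple(spoken)  # colors already announced behind this prisoner
        seen = hats[i + 1:]    # hats visible in front of this prisoner
        # Exactly one row agrees with everything this prisoner knows.
        (match,) = [r for r in table
                    if r[0] == statement and r[1:i + 1] == heard and r[i + 2:] == seen]
        spoken.append(match[i + 1])
    return tuple(spoken)


table = list(solve(10))
hats = tuple(random.choice(colors) for _ in range(9))  # prisoners 2..10
assert play(table, hats) == hats  # all 9 prisoners after the 1st survive
```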

The puzzle is solved, but the output is of exponential size, certainly not the succinct solution which makes the puzzle famous. But instead of relying on a flash of insight, this approach produces not just a solution, but the solution. The only arbitrary decision made was the enumeration. Therefore it must be the case that the solution can be summarized.

First, it would be helpful to group the solution by the 1st statement. Any summary function would have to ensure that there is no collision in the grouped possibilities.

groups = collections.defaultdict(set)
for row in table:
    groups[row[0]].add(row[1:])  # key: 1st statement; value: the colors it covers
groups = groups[False], groups[True]

def summarize(func, groups):
    """Apply summary function to groups and assert uniqueness."""
    groups = tuple(set(map(func, group)) for group in groups)
    assert set.isdisjoint(*groups)
    return groups

assert summarize(lambda g: g, groups) == groups
tuple(map(len, groups))
(256, 256)

Now, which summaries should we attempt? There are only a few properties of sequences to work with: size and order. The sequences are all the same size, so that won’t help. That leaves ordering, which can easily be tested by sorting.

summarize(lambda g: tuple(sorted(g)), groups)
({('blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'red'),
  ('blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'red', 'red', 'red'),
  ('blue', 'blue', 'blue', 'blue', 'red', 'red', 'red', 'red', 'red'),
  ('blue', 'blue', 'red', 'red', 'red', 'red', 'red', 'red', 'red'),
  ('red', 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'red')},
 {('blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'blue'),
  ('blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'blue', 'red', 'red'),
  ('blue', 'blue', 'blue', 'blue', 'blue', 'red', 'red', 'red', 'red'),
  ('blue', 'blue', 'blue', 'red', 'red', 'red', 'red', 'red', 'red'),
  ('blue', 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'red')})

Success. Now that order does not matter, the appropriate data structure is a multiset (a.k.a. bag): each prisoner need only keep track of how many of each color they hear and see.

summarize(lambda g: frozenset(collections.Counter(g).items()), groups)
({frozenset({('blue', 8), ('red', 1)}),
  frozenset({('blue', 2), ('red', 7)}),
  frozenset({('blue', 6), ('red', 3)}),
  frozenset({('blue', 4), ('red', 5)}),
  frozenset({('red', 9)})},
 {frozenset({('blue', 1), ('red', 8)}),
  frozenset({('blue', 7), ('red', 2)}),
  frozenset({('blue', 5), ('red', 4)}),
  frozenset({('blue', 3), ('red', 6)}),
  frozenset({('blue', 9)})})

Since there are only 2 colors, whose counts always sum to the same total (9 here), keeping track of just one count is sufficient.

summarize(lambda g: g.count(colors[0]), groups)
({1, 3, 5, 7, 9}, {0, 2, 4, 6, 8})
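With the summary reduced to a single count, a prisoner no longer needs the table at all, only the two sets of red counts. The sketch below assumes the sets printed above: index `False` (spoken “red”) and index `True` (spoken “blue”). The helpers `announce`, `guess`, and `play_by_count` are hypothetical. A later prisoner picks whichever color keeps the total red count inside the announced set.

```python
from itertools import product

colors = 'red', 'blue'

# Red-hat counts among prisoners 2..10, indexed by the 1st statement
# (False -> 'red', True -> 'blue'), copied from the summary above.
summary = ({1, 3, 5, 7, 9}, {0, 2, 4, 6, 8})


def announce(hats):
    """1st prisoner: choose the statement whose set contains the red count seen."""
    reds = hats.count(colors[0])
    (statement,) = [s for s in (False, True) if reds in summary[s]]
    return statement


def guess(statement, heard, seen):
    """Later prisoner: pick the color that keeps the total red count in the set."""
    reds = heard.count(colors[0]) + seen.count(colors[0])
    # A red hat of my own would make the total reds + 1; a blue one leaves it at reds.
    return colors[0] if reds + 1 in summary[statement] else colors[1]


def play_by_count(hats):
    """Simulate one round using counts only; `hats` are prisoners 2..10, back to front."""
    statement = announce(hats)
    spoken = []
    for i in range(len(hats)):
        spoken.append(guess(statement, spoken, hats[i + 1:]))
    return tuple(spoken)


# Exhaustively check all 2**9 arrangements of the 9 visible hats.
assert all(play_by_count(hats) == hats for hats in product(colors, repeat=9))
```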

There’s one last pattern to the numbers, which can be used to achieve parity with the canonical solution.
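One reading of that last pattern, assuming the pun on parity is the hint: the two sets are exactly the odd and even red counts, so the summary collapses to a single bit. The following sketch checks this against the sets above.

```python
# Red-hat counts for each 1st statement, as summarized above.
summary = ({1, 3, 5, 7, 9}, {0, 2, 4, 6, 8})

# Collapse each count to its parity: odd -> 1, even -> 0.
parities = tuple({count % 2 for count in group} for group in summary)
print(parities)  # ({1}, {0})
```

If that reading is right, the whole table reduces to one rule. The 1st prisoner says “red” (False) on seeing an odd number of red hats and “blue” otherwise. Everyone else then only needs to track the parity of the red hats they hear and see.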