2021: day21: part 2

This commit is contained in:
Antoine Martin 2021-12-21 14:40:45 +01:00
parent c92553fa40
commit a5fd485c10

View file

@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::fmt::Write;
use std::iter;
use std::ops::RangeInclusive;
@ -10,6 +11,7 @@ pub fn run() -> Result<String> {
let mut res = String::with_capacity(128);
writeln!(res, "part 1: {}", part1(INPUT)?)?;
writeln!(res, "part 2: {}", part2(INPUT)?)?;
Ok(res)
}
@ -48,6 +50,68 @@ fn part1(input: &str) -> Result<usize> {
Ok(loser_score * dice.rolls())
}
fn part2(input: &str) -> Result<usize> {
let mut lines = input.lines();
let player1_pos: PlayerPos = lines
.next()
.and_then(|line| line.trim().strip_prefix("Player 1 starting position: "))
.and_then(|pos| pos.parse().ok())
.map(PlayerPos::new)
.context("couldn't find player 1 pos")?;
let player2_pos: PlayerPos = lines
.next()
.and_then(|line| line.trim().strip_prefix("Player 2 starting position: "))
.and_then(|pos| pos.parse().ok())
.map(PlayerPos::new)
.context("couldn't find player 2 pos")?;
let (player1_score, player2_score) =
quantum_dice_game(player1_pos, player2_pos, 0, 0, &mut HashMap::new());
Ok(player1_score.max(player2_score))
}
type Cache = HashMap<(PlayerPos, PlayerPos, usize, usize), (usize, usize)>;
fn quantum_dice_game(
pos1: PlayerPos,
pos2: PlayerPos,
score1: usize,
score2: usize,
cache: &mut Cache,
) -> (usize, usize) {
// We swap players on each recursive call, so player 2 is the previous player 1. Player 1 is the
// only one who played, so we only need to check his score.
if score2 >= 21 {
return (0, 1);
}
// Memoization
if let Some(wins) = cache.get(&(pos1, pos2, score1, score2)) {
return *wins;
}
let (mut wins1, mut wins2) = (0, 0);
// 3 = 1 + 1 + 1
// 4 = 1 + 1 + 2, 1 + 2 + 1, 2 + 1 + 1
// ...
// 9 = 3 + 3 + 3
for (mv, times) in [(3, 1), (4, 3), (5, 6), (6, 7), (7, 6), (8, 3), (9, 1)] {
let mut pos1 = pos1; // copy
pos1.advance_by(mv);
// We swap out player 1 and 2 for the next recursion
let (w2, w1) = quantum_dice_game(pos2, pos1, score2, score1 + pos1.pos(), cache);
wins1 += w1 * times;
wins2 += w2 * times;
}
cache.insert((pos1, pos2, score1, score2), (wins1, wins2));
(wins1, wins2)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct PlayerPos(usize);
impl PlayerPos {
@ -106,4 +170,14 @@ mod tests {
fn part1_real() {
assert_eq!(part1(INPUT).unwrap(), 908595);
}
#[test]
fn part2_provided() {
assert_eq!(part2(PROVIDED).unwrap(), 444356092776315);
}
#[test]
fn part2_real() {
assert_eq!(part2(INPUT).unwrap(), 91559198282731);
}
}