rewritten using recursion and function output caching

This commit is contained in:
Andreas 2024-12-11 10:57:34 +01:00 committed by Andreas
parent 8f116fedab
commit c442207f20

View file

@ -1,34 +1,50 @@
from tqdm import tqdm from tqdm import tqdm
from argparse import ArgumentParser from argparse import ArgumentParser
import sys
def main(input, iterations): sys.setrecursionlimit(10**7)
input_list: list = input.split(" ")
for _ in tqdm(range(iterations)):
new_stones = []
for num in input_list:
if num == '0':
new_stones.extend(["1"])
elif len(num) % 2 == 0:
half = len(num) // 2
left_half = str(int(num[:half]))
right_half = str(int(num[half:]))
new_stones.extend([left_half, right_half]) def transform(num_str: str) -> list:
# Apply one iteration of the rules
if num_str == '0':
return ["1"]
elif len(num_str) % 2 == 0:
half = len(num_str) // 2
return [str(int(num_str[:half])), str(int(num_str[half:]))]
else: else:
new_stones.extend([str(int(num)*2024)]) return [str(int(num_str) * 2024)]
input_list = new_stones
return len(input_list)
def expand(num: str, iteration: int, cache: dict) -> int:
# Returns count of stones after 'iteration' blinks
if iteration == 0:
return 1
key = (num, iteration)
if key in cache:
return cache[key]
nxt = transform(num)
total = 0
for n in nxt:
total += expand(n, iteration - 1, cache)
cache[key] = total
return total
def main(input_str: str, iterations: int) -> int:
stones = input_str.split(" ")
cache = {}
count = 0
with tqdm(total=len(stones), desc="Expanding stones") as pbar:
for s in stones:
count += expand(s, iterations, cache)
pbar.update(1)
return count
if __name__ == "__main__": if __name__ == "__main__":
args = ArgumentParser( parser = ArgumentParser()
prog="Advend of Code Day 11", parser.add_argument("-p", "--part")
description="Solves the puzzle for day 11" parser.add_argument("-f", "--file")
) args = parser.parse_args()
args.add_argument("-p", "--part")
args.add_argument("-f", "--file") inp = open(args.file, 'r').read().strip()
args = args.parse_args() iters = 25 if args.part == "1" else (75 if args.part == "2" else None)
input_text = open(args.file, 'r').read()
iterations = 25 if args.part == "1" else 75 if args.part == "2" else None print(main(inp, iters))
answer = main(input_text, iterations)
print(str(answer))