rewritten using recursion and function output caching

This commit is contained in:
Andreas 2024-12-11 10:57:34 +01:00 committed by Andreas
parent 8f116fedab
commit c442207f20

View file

@ -1,34 +1,50 @@
from tqdm import tqdm
from argparse import ArgumentParser
import sys
def main(input, iterations):
input_list: list = input.split(" ")
for _ in tqdm(range(iterations)):
new_stones = []
for num in input_list:
if num == '0':
new_stones.extend(["1"])
elif len(num) % 2 == 0:
half = len(num) // 2
left_half = str(int(num[:half]))
right_half = str(int(num[half:]))
new_stones.extend([left_half, right_half])
else:
new_stones.extend([str(int(num)*2024)])
input_list = new_stones
return len(input_list)
sys.setrecursionlimit(10**7)
def transform(num_str: str) -> list:
# Apply one iteration of the rules
if num_str == '0':
return ["1"]
elif len(num_str) % 2 == 0:
half = len(num_str) // 2
return [str(int(num_str[:half])), str(int(num_str[half:]))]
else:
return [str(int(num_str) * 2024)]
def expand(num: str, iteration: int, cache: dict) -> int:
# Returns count of stones after 'iteration' blinks
if iteration == 0:
return 1
key = (num, iteration)
if key in cache:
return cache[key]
nxt = transform(num)
total = 0
for n in nxt:
total += expand(n, iteration - 1, cache)
cache[key] = total
return total
def main(input_str: str, iterations: int) -> int:
stones = input_str.split(" ")
cache = {}
count = 0
with tqdm(total=len(stones), desc="Expanding stones") as pbar:
for s in stones:
count += expand(s, iterations, cache)
pbar.update(1)
return count
if __name__ == "__main__":
args = ArgumentParser(
prog="Advend of Code Day 11",
description="Solves the puzzle for day 11"
)
args.add_argument("-p", "--part")
args.add_argument("-f", "--file")
args = args.parse_args()
input_text = open(args.file, 'r').read()
iterations = 25 if args.part == "1" else 75 if args.part == "2" else None
answer = main(input_text, iterations)
print(str(answer))
parser = ArgumentParser()
parser.add_argument("-p", "--part")
parser.add_argument("-f", "--file")
args = parser.parse_args()
inp = open(args.file, 'r').read().strip()
iters = 25 if args.part == "1" else (75 if args.part == "2" else None)
print(main(inp, iters))