-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
144 lines (129 loc) · 5.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from datetime import datetime
import sys
from arc_tools import grid
from arc_tools.logger import logger
from copy import deepcopy
from itertools import combinations
import os
import json
from arc_tools.grid import Grid, GridPoint, GridRegion, detect_objects
from arc_tools.plot import plot_grid, plot_grids, remove_pngs
from arc_tools.squash import squash_grid
from train_tasks import *
from evaluation_tasks import *
show_count = 0
from collections import Counter, deque # Add deque import
from arc_tools.grid import SubGrid # Add SubGrid import
from typing import Sequence # Add typing imports
from arc_tools.grid import Color
from arc_tools.grid import move_object
# Clear plot images left over from an earlier run (presumably the PNGs that
# plot_grid writes) before this run produces new ones.
remove_pngs()

# Candidate solvers that find_task() tries when the first input row and the
# first expected-output row have the same width. Flip the condition below to
# switch to the alternative candidate set.
# Other candidates used previously: row_col_color_data,
# color_swap_and_move_to_corner, dot_to_object, rope_stretch.
if False:
    normal_task_fns = [
        count_hollows_task,
        check_fit,
        move_object_without_collision,
        repeat_reverse_grid,
    ]
else:
    normal_task_fns = [fit_or_swap_fit]

# Candidate solvers that find_task() tries when those widths differ.
# (jigsaw_puzzle is a disabled candidate; row_col_color_data can also occur
# in "normal" tasks.)
jigsaw_task_fns = [row_col_color_data]
def debug_output(grid, expected_output, output):
    """Print the cells where ``output`` differs from ``expected_output``, then plot all three grids.

    Grids are accessed as ``g[row][col]``. If the two grids have different
    shapes, a shape-mismatch line is printed first. Only the overlapping
    region is then compared cell by cell, so a wrong-size output is reported
    instead of raising IndexError.

    Args:
        grid: the input grid; it is only used for plotting.
        expected_output: the grid the task function should have produced.
        output: the grid the task function actually produced.
    """
    expected_rows = len(expected_output)
    output_rows = len(output)
    expected_cols = len(expected_output[0]) if expected_rows else 0
    output_cols = len(output[0]) if output_rows else 0
    if (expected_rows, expected_cols) != (output_rows, output_cols):
        print(f"Shape mismatch: expected {expected_rows}x{expected_cols}, "
              f"got {output_rows}x{output_cols}")
    # Index-based access (rather than zip over rows) keeps to the
    # `g[row][col]` protocol the grids are known to support here.
    for i in range(min(expected_rows, output_rows)):
        for j in range(min(len(expected_output[i]), len(output[i]))):
            if expected_output[i][j] != output[i][j]:
                print(f"Cell {i}, {j} is different")
    plot_grids([grid, expected_output, output], show=1)
def find_task(grids, expected_outputs, start_train_task_id=1):
    """Return the first candidate task function that reproduces every given train pair.

    Which candidates are tried is decided by a width heuristic. If the first
    row of the first input grid has the same length as the first row of the
    first expected output, ``normal_task_fns`` is searched; otherwise
    ``jigsaw_task_fns`` is. If a module-level ``actual_task_name`` is set (the
    ``__main__`` block sets it from ``sys.argv[2]``), only that function is
    tried, and each failing pair is logged.

    Args:
        grids: ``Grid`` inputs of the train pairs to check.
        expected_outputs: raw expected outputs, parallel to ``grids``; each is
            wrapped in ``Grid`` before comparison.
        start_train_task_id: number used for the first pair in log messages.

    Returns:
        The first task function whose output ``compare``s equal to every
        expected output, or ``None`` if none does or there are no pairs.
    """
    if not grids or not expected_outputs:
        logger.info('No train pairs to check')
        return None
    if len(grids[0][0]) == len(expected_outputs[0][0]):
        task_fns = normal_task_fns
    else:
        task_fns = jigsaw_task_fns
    # `actual_task_name` is only bound when this file runs as a script (the
    # __main__ block assigns it at module level). Use .get() so callers that
    # import this module do not hit a NameError.
    forced_task_name = globals().get('actual_task_name')
    if forced_task_name:
        task_fns = [globals()[forced_task_name]]
    for task_fn in task_fns:
        logger.info(task_fn.__name__)
        right_task = True
        for task_id, (grid, raw_expected) in enumerate(zip(grids, expected_outputs), start_train_task_id):
            expected_output = Grid(raw_expected)
            plot_grid(expected_output, name="expected_output.png")
            plot_grid(grid, name="input.png")
            output = task_fn(grid)
            plot_grid(output, name="actual_output.png")
            if not output.compare(expected_output):
                # Uncomment to print differing cells and plot the three grids:
                # debug_output(grid, expected_output, output)
                if forced_task_name:
                    logger.info(f'Train task {task_id} failed')
                right_task = False
                break
            logger.info(f'Train task {task_id} passed')
        if right_task:
            return task_fn
        logger.info('--------------------------------')
    return None
def solve_task(data):
    """Find a task function from the train pairs of ``data`` and apply it to the test inputs.

    Args:
        data: ARC task dict with ``'train'`` and ``'test'`` lists of
            ``{'input': ..., 'output': ...}`` pairs. Test pairs may omit
            ``'output'``.

    Returns:
        A list with one ``{"attempt_1": ..., "attempt_2": ...}`` dict per
        processed test input. Both attempts hold the task function's output,
        or the unchanged input grid if no task function was found.

    Raises:
        Exception: if a test pair has a (non-empty) expected output and the
            task function's output does not ``compare`` equal to it.
    """
    num_train_tasks = len(data['train'])
    num_test_tasks = len(data['test'])
    logger.info(f"Number of train tasks: {num_train_tasks}, Number of test tasks: {num_test_tasks}")
    # Debug toggles (1-based ids of the first pair to process).
    # Train pair 1 is skipped. Its output is written to reference_output.json
    # below, presumably as a reference that task functions read.
    # NOTE(review): confirm that is why it is excluded from validation.
    start_train_task_id = 2
    start_test_task_id = 1
    # Set to a task function name to bypass the train search. This local
    # shadows the module-level `actual_task_name` set by __main__; find_task()
    # still consults that global on its own.
    actual_task_name = None
    grids = []
    expected_outputs = []
    actual_outputs = []
    with open('reference_output.json', 'w') as f:
        json.dump(data['train'][0]['output'], f)
    if not actual_task_name:
        for task_idx in range(start_train_task_id - 1, num_train_tasks):
            grids.append(Grid(data['train'][task_idx]['input']))
            expected_outputs.append(data['train'][task_idx]['output'])
        # With too few train pairs the slice above is empty; there is nothing
        # to validate a candidate against.
        task_fn = find_task(grids, expected_outputs, start_train_task_id) if grids else None
        if task_fn:
            logger.info(f"Found task: {task_fn.__name__}")
        else:
            logger.info("Task not found")
    else:
        task_fn = globals()[actual_task_name]
    for task_idx in range(start_test_task_id - 1, num_test_tasks):
        test_pair = data['test'][task_idx]
        grid = Grid(test_pair['input'])
        if task_fn:
            plot_grid(grid, name="input.png", show=0)
            # Test pairs may have no 'output' key; only wrap a real grid.
            raw_expected = test_pair.get('output')
            expected_output = Grid(raw_expected) if raw_expected is not None else None
            if expected_output is not None:
                plot_grid(expected_output, name="expected_output.png")
            output = task_fn(grid)
            plot_grid(output, name="actual_output.png")
            if expected_output:
                if output.compare(expected_output):
                    logger.info(f"Test task {task_idx + 1} passed")
                else:
                    logger.info(f"Test task {task_idx + 1} failed")
                    raise Exception(f"Incorrect task {task_idx + 1}: {task_fn.__name__}, Expected: {expected_output}, Actual: {output}")
            output = {"attempt_1": output, "attempt_2": output}
        else:
            # No solver found: echo the input as both attempts.
            output = {"attempt_1": grid, "attempt_2": grid}
        actual_outputs.append(output)
    return actual_outputs
if __name__ == "__main__":
    # Usage: main.py [task_hash [task_function_name]]
    task_hash = 'abc82100'
    if sys.argv[1:]:
        task_hash = sys.argv[1]
    # Assigned at module level on purpose: find_task() reads this global to
    # restrict the search to one named task function.
    actual_task_name = sys.argv[2] if sys.argv[2:] else None
    split = ['evaluation', 'training']
    # Paths are relative to the current working directory; the first split
    # containing the task file wins.
    for s in split:
        file = rf'../ARC-AGI-2/data/{s}/{task_hash}.json'
        if os.path.exists(file):
            break
    else:
        searched = [rf'../ARC-AGI-2/data/{s}/{task_hash}.json' for s in split]
        raise FileNotFoundError(f"Task {task_hash} not found in any of: {searched}")
    with open(file, 'r') as f:
        data = json.load(f)
    solve_task(data)