-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmllm_state.py
152 lines (121 loc) · 5.83 KB
/
mllm_state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from mllm_types import EnvConfig, GenerationOutput, GenerationEnv
from pathlib import Path
import subprocess, tempfile, hashlib
import os
import shutil
import yaml
import reactivedb
import argparse
from mllm_task_file import parse_text, unparse_items
class MergeConflictError(Exception):
    """Raised when one or more files could not be merged without conflicts.

    Attributes:
        conflict_files: the value passed to the constructor; ``apply_output``
            passes the list of changed-file paths whose merge failed.
    """

    def __init__(self, conflict_files):
        # Give Exception a message so str(exc) and tracebacks name the files
        # instead of being empty. repr() keeps this safe for any argument type.
        super().__init__(f'merge conflicts in files: {conflict_files!r}')
        self.conflict_files = conflict_files
def safe_child(root_path, path):
    """Resolve *path* relative to *root_path* and ensure it stays inside it.

    Both the root and the joined path are fully resolved (symlinks and ``..``
    components collapsed) before the containment check, so a root that is
    relative or reached through a symlink no longer rejects its own children.

    Args:
        root_path: directory that must contain the result.
        path: path to interpret relative to *root_path*. An absolute *path*
            replaces the root when joined and is therefore rejected unless it
            already lies under the root.

    Returns:
        The resolved absolute path of the child.

    Raises:
        ValueError: if the resolved path is not strictly below the resolved
            root (the root itself is rejected as well).
    """
    root = Path(root_path).resolve()
    abs_path = (root / path).resolve()
    # Component-wise containment check: immune to sibling-prefix confusion
    # (e.g. /data/root vs /data/root2) and excludes the root itself.
    if root not in abs_path.parents:
        raise ValueError('path outside of root path (%r -> %r)' % (path, abs_path))
    return abs_path
def generate_diff(env_config: EnvConfig, tip: str, base: str, changed_files: list[str], html: bool) -> dict[str, str]:
    """Diff every changed file between the *base* and *tip* revision trees.

    Each revision is looked up as a directory named after the revision under
    ``env_config.state_root``. With ``html`` set the external ``patdiff`` tool
    is invoked, otherwise ``diff -u``.

    Returns:
        A mapping from each entry of *changed_files* to the tool's standard
        output, decoded as UTF-8 with undecodable bytes replaced. The tool's
        exit status and stderr are not inspected.
    """
    # The command prefix does not depend on the file, so pick it once.
    if html:
        diff_cmd = ['patdiff', '-html', '-default', '-warn-if-no-trailing-newline-in-both', 'false']
    else:
        diff_cmd = ['diff', '-u']

    diffs = {}
    for name in changed_files:
        new_side = safe_child(env_config.state_root / tip, name)
        old_side = safe_child(env_config.state_root / base, name)
        # Old side first, new side second.
        completed = subprocess.run([*diff_cmd, old_side, new_side], capture_output=True)
        diffs[name] = completed.stdout.decode('utf8', 'replace')
    return diffs
def generate_diff_for_output(env_config: EnvConfig, output: GenerationOutput, html: bool):
    """Diff a generation output's tip revision against its base revision.

    Convenience wrapper around ``generate_diff`` that takes the revisions and
    the changed-file list from *output*.
    """
    return generate_diff(
        env_config,
        tip=output.tip_revision,
        base=output.base_revision,
        changed_files=output.changed_files,
        html=html,
    )
def merge3(mine: Path, older: Path, yours: Path, merge_result: Path,
           force=False):
    """Three-way merge using ``git merge-file`` and write the result to a file.

    The changes between *older* and *yours* are merged into the content of
    *mine*. Because ``--stdout`` is passed, none of the three inputs is
    modified; the merged text (with diff3-style conflict markers if there are
    conflicts) goes to *merge_result*, which is created or truncated.

    Args:
        mine: current version of the file.
        older: common ancestor version.
        yours: version carrying the changes to merge in.
        merge_result: where the merged text is written.
        force: when true, conflicts are not reported as failure.

    Returns:
        True if the merge was clean, or if *force* is set; False if conflicts
        were found and *force* is not set.

    Raises:
        Exception: if ``git merge-file`` did not finish with a conflict count,
            i.e. it exited with a status above 127 or was killed by a signal.
    """
    with open(merge_result, 'wb') as out:
        status = subprocess.call(
            ['git', 'merge-file', '--stdout', '--diff3',
             str(mine), str(older), str(yours)],
            stdout=out,
        )
    # git merge-file exits with the number of conflicts, capped at 127; larger
    # values signal an error. subprocess reports death-by-signal as a negative
    # code, which must not be mistaken for a (forced) success.
    if status < 0 or status > 127:
        raise Exception(f'git merge-file failed with unexpected exit code ({status})')
    return True if force else status == 0
def hash_file(file):
    """Return the hexadecimal SHA-256 digest of the contents of *file*.

    The file is read in binary mode in 1 MiB pieces so that large files are
    never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(file, "rb") as stream:
        # An empty read marks end of file.
        while block := stream.read(1024 * 1024):
            digest.update(block)
    return digest.hexdigest()
def can_merge(base: Path, tip: Path, live_path: Path) -> tuple[bool, str]:
    """Check whether *tip* can be merged into the live file without conflicts.

    Nothing is written to *live_path*: the merge runs against a private copy
    and its output is discarded.

    Args:
        base: common ancestor version of the file.
        tip: generated version of the file.
        live_path: current working version of the file.

    Returns:
        ``(ok, live_hash)``. *ok* is True when the merge would be clean.
        *live_hash* is the SHA-256 hex digest of the live content that was
        actually examined, or ``''`` when the file is brand new (present in
        neither *base* nor *live_path*), which is always mergeable.
    """
    # Same new-file rule as apply_output: no base and no live copy means the
    # tip file is simply added, so there is nothing to merge or hash.
    if not base.exists() and not live_path.exists():
        return True, ''
    with tempfile.TemporaryDirectory() as temp_dir:
        # Snapshot the live file so the hash describes exactly the content the
        # merge was checked against, even if the live file changes meanwhile.
        live_snapshot = Path(temp_dir) / "live_copy"
        shutil.copy(live_path, live_snapshot)
        ok = merge3(live_snapshot, base, tip, Path(os.devnull))
        live_hash = hash_file(live_snapshot)
    return ok, live_hash
def apply_output(env_config: EnvConfig, root_path: Path, output: GenerationOutput, force=False):
    """Merge a generation output's changes into the live tree under *root_path*.

    Works in two phases so that a conflict leaves the live tree untouched:
    first every changed file is three-way merged (live vs. base vs. tip) into
    a scratch directory; only if no file conflicts are the merged results
    copied over the live files and brand-new files added.

    Args:
        env_config: supplies ``state_root``, under which the base and tip
            revision trees are looked up by revision name.
        root_path: root of the live (working) tree.
        output: supplies ``base_revision``, ``tip_revision`` and
            ``changed_files``.
        force: passed to ``merge3``; when true, conflicting merges are
            accepted and written out with their conflict markers.

    Raises:
        MergeConflictError: if any file conflicts (and *force* is not set).
            No live file has been modified in that case.
    """
    new_files = []       # (live_file, tip_file) for files absent from base and live
    conflict_files = []  # changed-file entries whose merge reported conflicts

    # All intermediate merge results live in one scratch directory, so they
    # are removed on every exit path, including unexpected exceptions.
    with tempfile.TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        merged = []  # (live_file, merged_path) ready to be written back

        for index, filepath in enumerate(output.changed_files):
            base_file = safe_child(env_config.state_root / output.base_revision, filepath)
            tip_file = safe_child(env_config.state_root / output.tip_revision, filepath)
            live_file = safe_child(root_path, filepath)  # live (working) version

            if not base_file.exists() and not live_file.exists():
                # New in the tip and not present locally: plain copy later.
                new_files.append((live_file, tip_file))
                continue

            merged_path = scratch_dir / str(index)
            if merge3(live_file, base_file, tip_file, merged_path, force=force):
                merged.append((live_file, merged_path))
            else:
                conflict_files.append(filepath)

        if conflict_files:
            raise MergeConflictError(conflict_files)

        for live_file, merged_path in merged:
            # copyfile replaces only the content, so the existing live file
            # keeps its permission bits.
            shutil.copyfile(merged_path, live_file)

    for live_file, tip_file in new_files:
        live_file.parent.mkdir(exist_ok=True, parents=True)
        shutil.copy(tip_file, live_file)
def cmd_apply_output(env_config_path: Path, generation_id: int):
    """Apply one stored GenerationOutput to the tree next to the env config.

    The live root is the resolved directory containing *env_config_path*. The
    output record is loaded from ``db.sqlite3`` under the configured
    ``state_root``, applied with ``apply_output``, then flagged as applied and
    committed.

    Raises:
        Exception: if no GenerationOutput with *generation_id* is found.
    """
    project_root = env_config_path.parent.resolve()
    with open(env_config_path, 'r') as config_stream:
        env_config = EnvConfig.parse_obj(yaml.safe_load(config_stream))

    database = reactivedb.Db(env_config.state_root / "db.sqlite3")
    outputs = database.table(GenerationOutput)
    record = outputs.get(id=generation_id)
    if not record:
        raise Exception(f"GenerationOutput with id {generation_id} not found")

    apply_output(env_config, project_root, record)

    # Persist the fact that this output has been applied.
    record.applied = True
    outputs.set(record)
    database.commit()
def mark_as_done(todo_path, task_id):
    """Prefix the matching task(s) in a todo file with ``DONE ``.

    The file is parsed with ``parse_text``; every item whose ``task_id``
    equals *task_id*, has non-empty content, and whose first content entry
    does not already start with ``DONE `` gets that prefix. The file is then
    rewritten from ``unparse_items``, whether or not anything changed.
    """
    done_prefix = "DONE "
    items = parse_text(todo_path.read_text())
    for item in items:
        if not (item.task_id == task_id and item.content):
            continue
        first_entry = item.content[0]
        if not first_entry.startswith(done_prefix):
            item.content[0] = done_prefix + first_entry
    todo_path.write_text(unparse_items(items))
if __name__ == '__main__':
    # Command-line entry point: apply one GenerationOutput to the live tree.
    parser = argparse.ArgumentParser(description='Apply a GenerationOutput to the live codebase')
    # Both options are mandatory; required=True makes argparse print a usage
    # error instead of passing None on to cmd_apply_output.
    parser.add_argument('--env-config-path', type=Path, required=True, help='Path to env_config.yaml')
    parser.add_argument('--generation-id', type=int, required=True, help='ID of GenerationOutput to apply')
    args = parser.parse_args()
    cmd_apply_output(args.env_config_path, args.generation_id)