
Commit 9532ee5

add training scripts & solve dependency issue
1 parent dc1c856 commit 9532ee5

File tree

12 files changed, +1619 -6 lines changed


FeedbackPolicy/data/data.py

Lines changed: 954 additions & 0 deletions
Large diffs are not rendered by default.

FeedbackPolicy/models/factory.py

Lines changed: 57 additions & 1 deletion
@@ -1,5 +1,5 @@
+import os
 import torch
-import open_clip
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from visual_planner.trainer import GoalGaussianDiffusion
@@ -11,6 +11,13 @@
 from .vit import VisionTransformer
 
 
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+IMAGENET_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+IMAGENET_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
 def load_model(
     clip_vision_encoder_path: str,
     clip_vision_encoder_pretrained: str,
@@ -75,3 +82,52 @@ def load_model(
 
 
     return visual_planner, policy_model, tokenizer, text_encoder
+
+
+
+def create_feedback_policy(
+    vision_encoder: str = 'vc1-base', # TODO: Support additional visual encoders
+    resume_from_checkpoint: str = None,
+):
+
+    import torchvision.transforms as transforms
+    image_processor = transforms.Compose([
+        transforms.Resize((192, 192), interpolation=transforms.InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    ])
+    pretrained_model = "clip-vit-large-patch14"
+    text_tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model)
+
+    from vc_models.models.vit import model_utils
+    vision_encoder = model_utils.load_model(model_utils.VC1_BASE_NAME)
+    embd_size = 768
+
+
+    model = FeedbackDrivenPolicy(vision_encoder=vision_encoder,
+                                 vis_dim=embd_size,
+                                 window_size=5,
+                                 sampling_step=1)
+
+    model.vision_encoder.requires_grad_(False)
+
+    def check_file_exists(file_path):
+        if not os.path.isfile(file_path):
+            raise FileNotFoundError(f"The file '{file_path}' does not exist.")
+
+
+    print('Try loading from ckpt')
+    try:
+        check_file_exists(resume_from_checkpoint)
+        old_ckpt = torch.load(resume_from_checkpoint)['model_state_dict']
+
+        # remove 'module.' prefix in original keys (saved from a DDP-wrapped model)
+        new_ckpt = {}
+        for k, v in old_ckpt.items():
+            new_ckpt[k[7:]] = v
+        model.load_state_dict(new_ckpt, strict=False)
+
+    except FileNotFoundError as e:
+        print(e)
+
+    return model, image_processor, text_tokenizer
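
For orientation, here is a minimal sketch of how the new create_feedback_policy factory could be exercised from a script. It assumes the repository root is on PYTHONPATH so that FeedbackPolicy.models.factory is importable; the checkpoint path, image file, and instruction text are hypothetical, and the forward pass of FeedbackDrivenPolicy is omitted because its signature does not appear in this diff.

    # Hypothetical usage sketch of create_feedback_policy (paths and strings are made up).
    from PIL import Image

    from FeedbackPolicy.models.factory import create_feedback_policy

    model, image_processor, text_tokenizer = create_feedback_policy(
        vision_encoder='vc1-base',
        resume_from_checkpoint='checkpoints/policy_latest.pth',  # hypothetical path
    )

    # The returned image_processor is the torchvision pipeline defined above:
    # bicubic resize to 192x192, ToTensor, ImageNet-default normalization.
    frame = Image.open('frame_0000.png').convert('RGB')   # hypothetical frame
    pixels = image_processor(frame).unsqueeze(0)           # shape (1, 3, 192, 192)

    # The returned text_tokenizer is a CLIPTokenizer; tokenize a language instruction.
    tokens = text_tokenizer(
        ["pick up the red block"],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    print(pixels.shape, tokens["input_ids"].shape)

Note that the checkpoint is loaded with strict=False, so any keys that still mismatch after the 'module.' prefix is stripped are silently skipped rather than raising an error.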

FeedbackPolicy/models/policy.py

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@
 from einops import rearrange
 from einops import repeat
 
-from .transformer_utils import Block, PatchEmbed, get_2D_position_embeddings,\
-    RMSNorm, SwishGLU
+from .transformer_utils import Block, PatchEmbed, get_2D_position_embeddings, RMSNorm, SwishGLU
 
 
 
FeedbackPolicy/train/distributed.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@

"""
Util functions for setting up distributed training.
Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py
"""

import os
import torch

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def is_global_master(args):
    return args.rank == 0


def is_local_master(args):
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)


def is_using_horovod():
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    if all([var in os.environ for var in ompi_vars]) or all(
        [var in os.environ for var in pmi_vars]
    ):
        return True
    else:
        return False


def is_using_distributed():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"]) > 1
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"]) > 1
    return False


def world_info_from_env():
    local_rank = 0
    for v in (
        "LOCAL_RANK",
        "MPI_LOCALRANKID",
        "SLURM_LOCALID",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
        if v in os.environ:
            world_size = int(os.environ[v])
            break

    return local_rank, global_rank, world_size


def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # needed to run on single gpu
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
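
As a usage sketch only: the snippet below shows one way init_distributed_device could be wired into a training entry point launched with torchrun. The argparse flags are inferred from the attributes the function reads (horovod, dist_backend, dist_url, no_set_device_rank); the actual argument names and training loop in this repository's scripts are not part of this commit.

    # Sketch, not the repo's actual train script: the CLI flags below are
    # assumptions mirroring the attributes init_distributed_device reads.
    import argparse
    import torch

    from FeedbackPolicy.train.distributed import init_distributed_device, is_master

    parser = argparse.ArgumentParser()
    parser.add_argument("--horovod", action="store_true")
    parser.add_argument("--dist-backend", default="nccl", type=str)
    parser.add_argument("--dist-url", default="env://", type=str)
    parser.add_argument("--no-set-device-rank", action="store_true")
    args = parser.parse_args()

    device = init_distributed_device(args)  # fills args.rank / local_rank / world_size / device

    model = torch.nn.Linear(8, 8).to(device)  # placeholder model for illustration
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank] if device.type == "cuda" else None,
        )

    if is_master(args):
        print(f"rank 0 of {args.world_size}, device={device}")

    # Launched e.g. with: torchrun --nproc_per_node=4 train.py

Since the single-process fallback also calls init_process_group with the given dist_url, a plain python invocation with the default env:// rendezvous still expects MASTER_ADDR and MASTER_PORT to be set in the environment.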
