Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Vikram Voleti committed Jun 20, 2023
1 parent 45c51d1 commit 34e7627
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
14 changes: 7 additions & 7 deletions threestudio/models/guidance/zero123SD_vsd_guidance.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
import cv2
import importlib
import numpy as np
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass, field
from omegaconf import OmegaConf
from tqdm import tqdm

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
DDIMScheduler,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from tqdm import tqdm

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
Expand Down Expand Up @@ -663,8 +663,8 @@ def evaluate_guidance(self, cond, t_orig, latents, noise_pred):
for b, i in enumerate(idxs):
latents = latents_1step[b : b + 1]
c = {
"c_crossattn": [cond["c_crossattn"][0][b * 2 : b * 2 + 2]],
"c_concat": [cond["c_concat"][0][b * 2 : b * 2 + 2]],
"c_crossattn": [cond["c_crossattn"][0][[b, b + len(idxs)], ...]],
"c_concat": [cond["c_concat"][0][[b, b + len(idxs)], ...]],
}
for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
# pred noise
Expand Down
19 changes: 10 additions & 9 deletions threestudio/models/guidance/zero123_vsd_guidance.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import cv2
import importlib
import numpy as np
import os
from dataclasses import dataclass
from omegaconf import OmegaConf
from tqdm import tqdm

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler
from omegaconf import OmegaConf
from tqdm import tqdm

import threestudio
from extern.ldm_zero123.modules.attention import BasicTransformerBlock, CrossAttention
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseModule
from threestudio.utils.misc import C, cleanup
from threestudio.utils.typing import *

from extern.ldm_zero123.modules.attention import BasicTransformerBlock, CrossAttention


def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
Expand Down Expand Up @@ -310,7 +309,9 @@ def get_cond(
def extract_from_cond(self, cond, n_samples=1) -> dict:
only_cond = {}
for key in cond:
only_cond[key] = [torch.cat(n_samples * [cond[key][0][-len(cond[key][0])//2:]])]
only_cond[key] = [
torch.cat(n_samples * [cond[key][0][-len(cond[key][0]) // 2 :]])
]
return only_cond

def compute_grad_vsd(
Expand Down Expand Up @@ -496,8 +497,8 @@ def evaluate_guidance(self, cond, t_orig, latents, noise_pred):
for b, i in enumerate(idxs):
latents = latents_1step[b : b + 1]
c = {
"c_crossattn": [cond["c_crossattn"][0][b * 2 : b * 2 + 2]],
"c_concat": [cond["c_concat"][0][b * 2 : b * 2 + 2]],
"c_crossattn": [cond["c_crossattn"][0][[b, b + len(idxs)], ...]],
"c_concat": [cond["c_concat"][0][[b, b + len(idxs)], ...]],
}
for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
# pred noise
Expand Down

0 comments on commit 34e7627

Please sign in to comment.