Skip to content

Commit

Permalink
Merge pull request #10 from VishwamAI/devin/hardware-path/27644
Browse files Browse the repository at this point in the history
Enhancements to ObjectDetector and Model Training
  • Loading branch information
kasinadhsarma authored Aug 2, 2024
2 parents 49f45b4 + 132f069 commit 5ec7a72
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
21 changes: 18 additions & 3 deletions modules/ObjectDetector.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
from .object_detection_interface import ObjectDetectionInterface, preprocess_image
import numpy as np
from typing import List, Union

class ObjectDetector:
def __init__(self):
self.detector = ObjectDetectionInterface()

def detect_objects(self, image):
preprocessed_image = preprocess_image(image)
return self.detector.detect(preprocessed_image)
def detect_objects(self, images: Union[np.ndarray, List[np.ndarray]]):
try:
if isinstance(images, np.ndarray):
images = [images]

preprocessed_images = [preprocess_image(img) for img in images]
batch_results = self.detector.detect_batch(preprocessed_images)

return [self._post_process(result) for result in batch_results]
except Exception as e:
print(f"Error in object detection: {str(e)}")
return []

def _post_process(self, detection_result):
# Implement post-processing logic here (e.g., non-max suppression, filtering)
return detection_result
37 changes: 33 additions & 4 deletions modules/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,23 @@ def __call__(self, x):
x = DenseLayer(features=self.action_size)(x)
return x

def create_train_state(rng, input_shape, action_size, learning_rate=0.0005):
def create_train_state(rng, input_shape, action_size, learning_rate=0.001):
def forward_fn(x):
model = ModelTrainer(num_layers=3, hidden_size=64, num_heads=4, dropout_rate=0.1)
model.set_action_size(action_size)
return model(x)

transformed_forward = hk.transform(forward_fn)
params = transformed_forward.init(rng, jnp.ones(input_shape))
tx = adam(learning_rate)

# Use Adam optimizer with learning rate schedule
schedule_fn = optax.exponential_decay(init_value=learning_rate, transition_steps=1000, decay_rate=0.9)
tx = optax.chain(
optax.clip_by_global_norm(1.0), # Gradient clipping
optax.scale_by_adam(),
optax.scale_by_schedule(schedule_fn)
)

return train_state.TrainState.create(apply_fn=transformed_forward.apply, params=params, tx=tx)

class ModelTrainerWrapper:
Expand All @@ -58,6 +66,9 @@ def __init__(self, input_shape, n_actions):
self.target_params = self.state.params
self.forward = hk.transform(lambda x: ModelTrainer(num_layers=3, hidden_size=64, num_heads=4, dropout_rate=0.1)(x))
self.population = self._initialize_population()
self.best_loss = float('inf')
self.patience = 10
self.patience_counter = 0

def _initialize_population(self):
return [self._mutate_params(self.state.params) for _ in range(pop_size)]
Expand Down Expand Up @@ -128,12 +139,30 @@ def update(self, batch):
self.population = new_population
self.state = train_state.TrainState.create(apply_fn=self.state.apply_fn, params=self.population[0], tx=self.state.tx)

return min(losses)
min_loss = min(losses)

# Early stopping
if min_loss < self.best_loss:
self.best_loss = min_loss
self.patience_counter = 0
else:
self.patience_counter += 1

if self.patience_counter >= self.patience:
print("Early stopping triggered")
return None

return min_loss

def _tournament_selection(self, population, fitnesses):
tournament_size = random.randint(tournament_size_min, tournament_size_max)
tournament = random.sample(list(zip(population, fitnesses)), tournament_size)
return min(tournament, key=lambda x: x[1])[0]

def update_target(self):
self.target_params = self.state.params
self.target_params = self.state.params

def compress_model(self):
# Implement model compression techniques here
# For example, pruning, quantization, or knowledge distillation
pass

0 comments on commit 5ec7a72

Please sign in to comment.