diff --git a/modules/ObjectDetector.py b/modules/ObjectDetector.py
index 3e52556..cc898fe 100644
--- a/modules/ObjectDetector.py
+++ b/modules/ObjectDetector.py
@@ -1,9 +1,24 @@
 from .object_detection_interface import ObjectDetectionInterface, preprocess_image
+import numpy as np
+from typing import List, Union
 
 class ObjectDetector:
     def __init__(self):
         self.detector = ObjectDetectionInterface()
 
-    def detect_objects(self, image):
-        preprocessed_image = preprocess_image(image)
-        return self.detector.detect(preprocessed_image)
\ No newline at end of file
+    def detect_objects(self, images: Union[np.ndarray, List[np.ndarray]]):
+        try:
+            if isinstance(images, np.ndarray):
+                images = [images]
+
+            preprocessed_images = [preprocess_image(img) for img in images]
+            batch_results = self.detector.detect_batch(preprocessed_images)
+
+            return [self._post_process(result) for result in batch_results]
+        except Exception as e:
+            print(f"Error in object detection: {str(e)}")
+            return []
+
+    def _post_process(self, detection_result):
+        # Implement post-processing logic here (e.g., non-max suppression, filtering)
+        return detection_result
diff --git a/modules/model_training.py b/modules/model_training.py
index f5053e4..3e6a442 100644
--- a/modules/model_training.py
+++ b/modules/model_training.py
@@ -38,7 +38,7 @@ def __call__(self, x):
         x = DenseLayer(features=self.action_size)(x)
         return x
 
-def create_train_state(rng, input_shape, action_size, learning_rate=0.0005):
+def create_train_state(rng, input_shape, action_size, learning_rate=0.001):
     def forward_fn(x):
         model = ModelTrainer(num_layers=3, hidden_size=64, num_heads=4, dropout_rate=0.1)
         model.set_action_size(action_size)
@@ -46,7 +46,15 @@ def forward_fn(x):
 
     transformed_forward = hk.transform(forward_fn)
     params = transformed_forward.init(rng, jnp.ones(input_shape))
-    tx = adam(learning_rate)
+
+    # Use Adam optimizer with learning rate schedule
+    schedule_fn = optax.exponential_decay(init_value=learning_rate, transition_steps=1000, decay_rate=0.9)
+    tx = optax.chain(
+        optax.clip_by_global_norm(1.0),  # Gradient clipping
+        optax.scale_by_adam(),
+        optax.scale_by_schedule(schedule_fn)
+    )
+
     return train_state.TrainState.create(apply_fn=transformed_forward.apply, params=params, tx=tx)
 
 class ModelTrainerWrapper:
@@ -58,6 +66,9 @@ def __init__(self, input_shape, n_actions):
         self.target_params = self.state.params
         self.forward = hk.transform(lambda x: ModelTrainer(num_layers=3, hidden_size=64, num_heads=4, dropout_rate=0.1)(x))
         self.population = self._initialize_population()
+        self.best_loss = float('inf')
+        self.patience = 10
+        self.patience_counter = 0
 
     def _initialize_population(self):
         return [self._mutate_params(self.state.params) for _ in range(pop_size)]
@@ -128,7 +139,20 @@ def update(self, batch):
         self.population = new_population
         self.state = train_state.TrainState.create(apply_fn=self.state.apply_fn, params=self.population[0], tx=self.state.tx)
 
-        return min(losses)
+        min_loss = min(losses)
+
+        # Early stopping
+        if min_loss < self.best_loss:
+            self.best_loss = min_loss
+            self.patience_counter = 0
+        else:
+            self.patience_counter += 1
+
+        if self.patience_counter >= self.patience:
+            print("Early stopping triggered")
+            return None
+
+        return min_loss
 
     def _tournament_selection(self, population, fitnesses):
         tournament_size = random.randint(tournament_size_min, tournament_size_max)
@@ -136,4 +160,9 @@ def _tournament_selection(self, population, fitnesses):
         return min(tournament, key=lambda x: x[1])[0]
 
     def update_target(self):
-        self.target_params = self.state.params
\ No newline at end of file
+        self.target_params = self.state.params
+
+    def compress_model(self):
+        # Implement model compression techniques here
+        # For example, pruning, quantization, or knowledge distillation
+        pass
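
Note on the new optimizer chain in create_train_state: optax.apply_updates adds the transformed updates to the parameters, so a chain that ends with a positive schedule (clip_by_global_norm -> scale_by_adam -> scale_by_schedule) steps in the gradient direction. If the TrainState is ever advanced with apply_gradients rather than only through the evolutionary population update, the usual descent-correct pattern negates the schedule (or appends optax.scale(-1.0)). A minimal sketch of that pattern with the same hyperparameters as the diff; make_optimizer is an illustrative helper, not an existing function in the repository:

import optax

def make_optimizer(learning_rate=0.001):
    # Same schedule as the diff: exponential decay by 0.9 every 1000 steps.
    schedule_fn = optax.exponential_decay(
        init_value=learning_rate, transition_steps=1000, decay_rate=0.9)
    return optax.chain(
        optax.clip_by_global_norm(1.0),   # gradient clipping
        optax.scale_by_adam(),            # Adam moment estimates
        # Negate the schedule so apply_updates (params + updates)
        # performs gradient descent rather than ascent.
        optax.scale_by_schedule(lambda step: -schedule_fn(step)),
    )

The _post_process stub currently passes detections through unchanged. A self-contained non-max-suppression sketch is shown below; it assumes each detection is a dict with 'box' ([x1, y1, x2, y2]) and 'score' keys, which is an assumption about the output of ObjectDetectionInterface.detect_batch rather than its documented format:

def nms(detections, iou_threshold=0.5, score_threshold=0.3):
    # Drop low-confidence detections, then greedily keep the highest-scoring
    # box and suppress any remaining box that overlaps a kept box too strongly.
    kept = []
    for det in sorted((d for d in detections if d["score"] >= score_threshold),
                      key=lambda d: d["score"], reverse=True):
        if all(iou(det["box"], k["box"]) < iou_threshold for k in kept):
            kept.append(det)
    return kept

def iou(a, b):
    # Intersection-over-union of two [x1, y1, x2, y2] boxes.
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

Under that assumption, _post_process(detection_result) could simply return nms(detection_result).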