Skip to content

Pytorch 1.11 #894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: pytorch-1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Introduction

### Good news! This repo supports pytorch-1.0 now!!! We borrowed some code and techniques from [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark).
### Good news! This repo supports pytorch-1.11 now!!! We borrowed some code and techniques from [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark).

This project is a *faster* pytorch implementation of faster R-CNN, aimed at accelerating the training of faster R-CNN object detection models. Recently, there are a number of good implementations:

Expand Down Expand Up @@ -34,7 +34,8 @@ During our implementing, we referred the above implementations, especially [long
- [x] Add deformable pooling layer (mainly supported by [Xander](https://github.com/xanderchf)).
- [x] Support pytorch-0.4.0 (go to master branch).
- [x] Support tensorboardX.
- [x] Support pytorch-1.0 (this branch).
- [x] Support pytorch-1.0.
- [x] Support pytorch-1.11 (this branch).

## Other Implementations

Expand Down Expand Up @@ -108,8 +109,8 @@ cd faster-rcnn.pytorch && mkdir data
### prerequisites

* Python 2.7 or 3.6
* Pytorch 1.0 (for Pytorch 0.4.0 go to master branch)
* CUDA 8.0 or higher
* Pytorch 1.11+ (for Pytorch 0.4.0 go to master branch)
* CUDA 11.1 or higher

### Data Preparation

Expand Down
17 changes: 7 additions & 10 deletions lib/model/csrc/cuda/ROIAlign_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <ATen/ceil_div.h>

// TODO make it in a common file
#define CUDA_1D_KERNEL_LOOP(i, n) \
Expand Down Expand Up @@ -272,11 +269,11 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
dim3 grid(std::min(at::ceil_div(output_size, 512L), 4096L));
dim3 block(512);

if (output.numel() == 0) {
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return output;
}

Expand All @@ -294,7 +291,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
rois.contiguous().data<scalar_t>(),
output.data<scalar_t>());
});
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return output;
}

Expand All @@ -317,12 +314,12 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
dim3 grid(std::min(at::ceil_div(grad.numel(), 512L), 4096L));
dim3 block(512);

// handle possibly empty gradients
if (grad.numel() == 0) {
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

Expand All @@ -341,6 +338,6 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
17 changes: 7 additions & 10 deletions lib/model/csrc/cuda/ROIPool_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <ATen/ceil_div.h>


// TODO make it in a common file
Expand Down Expand Up @@ -126,11 +123,11 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
dim3 grid(std::min(at::ceil_div(output_size, 512L), 4096L));
dim3 block(512);

if (output.numel() == 0) {
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}

Expand All @@ -148,7 +145,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
output.data<scalar_t>(),
argmax.data<int>());
});
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}

Expand All @@ -173,12 +170,12 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
dim3 grid(std::min(at::ceil_div(grad.numel(), 512L), 4096L));
dim3 block(512);

// handle possibly empty gradients
if (grad.numel() == 0) {
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

Expand All @@ -197,6 +194,6 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
22 changes: 9 additions & 13 deletions lib/model/csrc/cuda/nms.cu
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include <ATen/ceil_div.h>
#include <ATen/cuda/ThrustAllocator.h>

#include <vector>
#include <iostream>
Expand Down Expand Up @@ -61,7 +60,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
t |= 1ULL << i;
}
}
const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
Expand All @@ -76,28 +75,25 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {

int boxes_num = boxes.size(0);

const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);

scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();

THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

unsigned long long* mask_dev = NULL;
//THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
// boxes_num * col_blocks * sizeof(unsigned long long)));

mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
THCCeilDiv(boxes_num, threadsPerBlock));
dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
at::ceil_div(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
nms_kernel<<<blocks, threads>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);

std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
THCudaCheck(cudaMemcpy(&mask_host[0],
C10_CUDA_CHECK(cudaMemcpy(&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
cudaMemcpyDeviceToHost));
Expand All @@ -122,7 +118,7 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
}
}

THCudaFree(state, mask_dev);
c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
// TODO improve this part
return std::get<0>(order_t.index({
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
Expand Down
2 changes: 1 addition & 1 deletion lib/model/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def cfg_from_file(filename):
"""Load a config file and merge it into the default options."""
import yaml
with open(filename, 'r') as f:
yaml_cfg = edict(yaml.load(f))
yaml_cfg = edict(yaml.safe_load(f))

_merge_a_into_b(yaml_cfg, __C)

Expand Down
1 change: 0 additions & 1 deletion lib/pycocotools/UPSTREAM_REV

This file was deleted.

1 change: 0 additions & 1 deletion lib/pycocotools/__init__.py

This file was deleted.

Loading