-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from CoVital-Project/add_dataloader
Add dataloader and Semcova data
- Loading branch information
Showing
125 changed files
with
503 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
sample_data/**/*.mp4 filter=lfs diff=lfs merge=lfs -text | ||
sample_data/*/*.mp4 filter=lfs diff=lfs merge=lfs -text | ||
nemcova_data/**/*.mp4 filter=lfs diff=lfs merge=lfs -text | ||
nemcova_data/*/*.mp4 filter=lfs diff=lfs merge=lfs -text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# CoVital Pytorch Dataset and Dataloader | ||
|
||
In this folder there is the functionality to load videos from a folder following the structure described in [README](../sample_data/README.md). | ||
A Dataset instance can be created by: | ||
``` | ||
dataset = Spo2Dataset(data_path) | ||
``` | ||
This will itearate the folders at data_path and load their videos and compute the mean and std for each channel per frame and load the ground truth as labels. It also allows to store metadata for each video. Once the dataset is created, it will contain only the mean and std. This process is slow as to preserve memory, it does it frame per frame. Feel free to modify it if you know ways to speed up the process without causing memory issues. Using torchvision.io.read_video ran out of memory in a 32 gb RAM computer. | ||
|
||
Once the dataset is ready, it can be fed to a DataLoader object. | ||
|
||
``` | ||
dataloader = Spo2DataLoader(dataset, batch_size=4, collate_fn= Spo2DataLoader.collate_fn) | ||
``` | ||
|
||
The output needs to be batched tensors, and therefore they have to share the same length. Since we have videos of different lengths, it pads the shorted ones to fit the length of the longest one in each frame. This may be an issue for models which require the same length for all batches, but it is convinient for RNN models. The real length of each video is accessible for each batch. Each batch returns three variables, videos_batch, labels_batch and videos_lengths. | ||
|
||
There are two versions, one using Threading, but the performance was very similar for both. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from torch.utils.data import Dataset, DataLoader | ||
import cv2 | ||
import numpy as np | ||
import torchvision | ||
import os | ||
import json | ||
import torch | ||
from threading import Thread | ||
from queue import Queue | ||
import time | ||
|
||
#def timing(f): | ||
#def wrap(*args): | ||
#time1 = time.time() | ||
#ret = f(*args) | ||
#time2 = time.time() | ||
#print('{:s} function took {:.3f} ms'.format(f.__name__, (time2-time1)*1000.0)) | ||
|
||
#return ret | ||
#return wrap | ||
|
||
|
||
class Spo2Dataset(Dataset): | ||
"""Spo2Dataset dataset. | ||
It preprocess the data in order to create a Dataset with the average and std of each channel per frame. | ||
The process is slow so it may take a while to create the Dataset when first initated. | ||
""" | ||
#@timing | ||
def reshape(self, frame): | ||
return frame.reshape(-1,3) | ||
|
||
#@timing | ||
def mean_t(self, frame): | ||
return np.array([frame.mean(axis=0), frame.std(axis=0)]).T | ||
|
||
#@timing | ||
def transform(self,frame): | ||
frame = self.reshape(frame) | ||
ret = self.mean_t(frame) | ||
return ret | ||
|
||
#@timing | ||
def get_channels(self, frame, blue = 0, green = 1, red = 2): | ||
blue_channel = frame[:,:,blue] | ||
green_channel = frame[:,:,green] | ||
red_channel = frame[:,:,red] | ||
|
||
return blue_channel, green_channel, red_channel | ||
|
||
#@timing | ||
def mean_fast(self, blue_channel, green_channel, red_channel): | ||
blue_channel_mean = blue_channel.mean() | ||
green_channel_mean = green_channel.mean() | ||
red_channel_mean = red_channel.mean() | ||
|
||
return blue_channel_mean, green_channel_mean, red_channel_mean | ||
|
||
#@timing | ||
def std_fast(self, blue_channel, green_channel, red_channel): | ||
blue_channel_mean = blue_channel.std() | ||
green_channel_mean = green_channel.std() | ||
red_channel_mean = red_channel.std() | ||
|
||
return blue_channel_mean, green_channel_mean, red_channel_mean | ||
|
||
#@timing | ||
def transform_faster(self, frame): | ||
blue_channel, green_channel, red_channel = self.get_channels(frame) | ||
blue_channel_mean, green_channel_mean, red_channel_mean = self.mean_fast(blue_channel, green_channel, red_channel) | ||
blue_channel_std, green_channel_std, red_channel_std = self.std_fast(blue_channel, green_channel, red_channel) | ||
|
||
return np.array([[blue_channel_mean, blue_channel_std], | ||
[green_channel_mean, green_channel_std], | ||
[red_channel_mean, red_channel_std]]) | ||
|
||
def __init__(self, data_path): | ||
""" | ||
Args: | ||
data_path (string): Path to the data folder. | ||
""" | ||
self.data_path = data_path | ||
self.video_folders = [folder for folder in os.listdir(data_path) if os.path.isdir(os.path.join(data_path,folder))] | ||
self.videos_ppg = [] | ||
self.labels_list = [] | ||
self.meta_list = [] | ||
|
||
nb_video = 1 | ||
for video in self.video_folders: | ||
print("Loading video:", nb_video) | ||
nb_video += 1 | ||
ppg = [] | ||
video_path = os.path.join(self.data_path, video) | ||
video_file = os.path.join(video_path, [file_name for file_name in os.listdir(video_path) if file_name.endswith('mp4')][0]) | ||
vidcap = cv2.VideoCapture(video_file) | ||
meta = {} | ||
meta['video_fps'] = vidcap.get(cv2.CAP_PROP_FPS) | ||
(grabbed, frame) = vidcap.read() | ||
#frame_count = 0 | ||
while grabbed: | ||
frame = self.transform_faster(frame) | ||
ppg.append(frame) | ||
(grabbed, frame) = vidcap.read() | ||
#if(frame_count % 50 == 0): | ||
#print("Frame:", frame_count) | ||
#frame_count += 1 | ||
with open(os.path.join(video_path, 'gt.json'), 'r') as f: | ||
ground_truth = json.load(f) | ||
|
||
labels = torch.Tensor([int(ground_truth['SpO2']), int(ground_truth['HR'])]) | ||
self.videos_ppg.append(torch.Tensor(np.array(ppg))) | ||
self.meta_list.append(meta) | ||
self.labels_list.append(labels) | ||
def __len__(self): | ||
return len(self.video_folders) | ||
|
||
def __getitem__(self, idx): | ||
if torch.is_tensor(idx): | ||
idx = idx.tolist() | ||
return [self.videos_ppg[idx],self.meta_list[idx],self.labels_list[idx]] | ||
|
||
class Spo2DataLoader(DataLoader): | ||
def collate_fn(batch): | ||
videos_length = [element[0].shape[0] for element in batch] | ||
max_length = max(videos_length) | ||
videos_tensor = torch.FloatTensor(size=[len(videos_length),max_length, 3, 2]) | ||
labels_tensor = torch.FloatTensor(size=[len(videos_length), 2]) | ||
for i, element in enumerate(batch): | ||
padding = max_length-videos_length[i] | ||
if padding > 0: | ||
padding = torch.zeros([padding,3,2]) | ||
video = torch.cat([element[0], padding]) | ||
else: | ||
video = element[0] | ||
labels = element[2] | ||
videos_tensor[i] = video | ||
labels_tensor[i] = element[2] | ||
return videos_tensor, labels_tensor, torch.Tensor(videos_length) | ||
|
||
if __name__== "__main__": | ||
dataset = Spo2Dataset('sample_data') | ||
dataloader = Spo2DataLoader(dataset, batch_size=4, collate_fn= Spo2DataLoader.collate_fn) | ||
for videos_batch, labels_batch, videos_lengths in dataloader: | ||
print('Padded video (length, color, (mean,std)): ', videos_batch[0].shape) | ||
print('Video original length: ', videos_lengths[0]) | ||
print('Labels (so2, hr): ', labels_batch[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from torch.utils.data import Dataset, DataLoader | ||
import cv2 | ||
import numpy as np | ||
import torchvision | ||
import os | ||
import json | ||
import torch | ||
from threading import Thread | ||
from queue import Queue | ||
import time | ||
|
||
class VideoGet: | ||
""" | ||
Class that continuously gets frames from a VideoCapture object | ||
with a dedicated thread. | ||
It stores the frames in a list. | ||
It adds new frames only when there is less than a 100 in the "stack" to help overcome memory issues. | ||
""" | ||
|
||
def __init__(self, src=0, stack_size = 80): | ||
self.stream = cv2.VideoCapture(src) | ||
self.stack_size = stack_size | ||
(self.grabbed, self.frame) = self.stream.read() | ||
self.stopped = False | ||
self.frames = [self.frame] | ||
self.fps = self.stream.get(cv2.CAP_PROP_FPS) | ||
def start(self): | ||
Thread(target=self.get, args=()).start() | ||
return self | ||
|
||
def get(self): | ||
while not self.stopped: | ||
if not self.grabbed: | ||
self.stop() | ||
elif len(self.frames) < self.stack_size: | ||
(self.grabbed, self.frame) = self.stream.read() | ||
self.frames.append(self.frame) | ||
|
||
def stop(self): | ||
self.stopped = True | ||
|
||
class Spo2Dataset(Dataset): | ||
"""Spo2Dataset dataset. | ||
It preprocess the data in order to create a Dataset with the average and std of each channel per frame. | ||
The process is slow so it may take a while to create the Dataset when first initated. | ||
""" | ||
def transform(self,frame): | ||
frame = frame.reshape(len(frame),-1,3) | ||
return np.array([frame.mean(axis=1), frame.std(axis=1)]).reshape(len(frame),3,2) | ||
def __init__(self, data_path): | ||
""" | ||
Args: | ||
data_path (string): Path to the data folder. | ||
""" | ||
self.data_path = data_path | ||
self.video_folders = [folder for folder in os.listdir(data_path) if os.path.isdir(os.path.join(data_path,folder))] | ||
self.videos_ppg = [] | ||
self.labels_list = [] | ||
self.meta_list = [] | ||
|
||
for video in self.video_folders: | ||
ppg = [] | ||
video_path = os.path.join(self.data_path, video) | ||
video_file = os.path.join(video_path, [file_name for file_name in os.listdir(video_path) if file_name.endswith('mp4')][0]) | ||
vidcap = VideoGet(video_file).start() | ||
meta = {} | ||
meta['video_fps'] = vidcap.fps | ||
while True: | ||
if vidcap.stopped and len(vidcap.frames)==1: | ||
vidcap.stop() | ||
break | ||
if len(vidcap.frames)>1: | ||
frames = np.array(vidcap.frames[:-1]) | ||
vidcap.frames = vidcap.frames[len(frames):] | ||
frame = self.transform(frames) | ||
ppg.extend(frame) | ||
|
||
with open(os.path.join(video_path, 'gt.json'), 'r') as f: | ||
ground_truth = json.load(f) | ||
|
||
labels = torch.Tensor([int(ground_truth['SpO2']), int(ground_truth['HR'])]) | ||
self.videos_ppg.append(torch.Tensor(np.array(ppg))) | ||
self.meta_list.append(meta) | ||
self.labels_list.append(labels) | ||
def __len__(self): | ||
return len(self.video_folders) | ||
|
||
def __getitem__(self, idx): | ||
if torch.is_tensor(idx): | ||
idx = idx.tolist() | ||
return [self.videos_ppg[idx],self.meta_list[idx],self.labels_list[idx]] | ||
|
||
class Spo2DataLoader(DataLoader): | ||
def collate_fn(batch): | ||
videos_length = [element[0].shape[0] for element in batch] | ||
max_length = max(videos_length) | ||
videos_tensor = torch.FloatTensor(size=[len(videos_length),max_length, 3, 2]) | ||
labels_tensor = torch.FloatTensor(size=[len(videos_length), 2]) | ||
for i, element in enumerate(batch): | ||
padding = max_length-videos_length[i] | ||
if padding > 0: | ||
padding = torch.zeros([padding,3,2]) | ||
video = torch.cat([element[0], padding]) | ||
else: | ||
video = element[0] | ||
labels = element[2] | ||
videos_tensor[i] = video | ||
labels_tensor[i] = element[2] | ||
return videos_tensor, labels_tensor, torch.Tensor(videos_length) | ||
|
||
if __name__== "__main__": | ||
dataset = Spo2Dataset('sample_data') | ||
dataloader = Spo2DataLoader(dataset, batch_size=4, collate_fn= Spo2DataLoader.collate_fn) | ||
for videos_batch, labels_batch, videos_lengths in dataloader: | ||
print('Padded video (length, color, (mean,std)): ', videos_batch[0].shape) | ||
print('Video original length: ', videos_lengths[0]) | ||
print('Labels (so2, hr): ', labels_batch[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,5 +44,3 @@ | |
plt.plot(x, spo2_smooth) | ||
plt.show() | ||
|
||
|
||
|
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 96, "HR": 68, "VideoFilename": "data_08b.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "Vibe S1", "PhoneMake": "Lenovo"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 98, "HR": 59, "VideoFilename": "data_05b.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "Vibe S1", "PhoneMake": "Lenovo"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 96, "HR": 67, "VideoFilename": "data_07a.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "S750", "PhoneMake": "Lenovo"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 99, "HR": 53, "VideoFilename": "data_03b.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "Vibe S1", "PhoneMake": "Lenovo"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 94, "HR": 69, "VideoFilename": "data_09b.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "Vibe S1", "PhoneMake": "Lenovo"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 94, "HR": 72, "VideoFilename": "data_09a.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "Galaxy A3", "PhoneMake": "Samsung"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 98, "HR": 93, "VideoFilename": "data_04b.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "Vibe S1", "PhoneMake": "Lenovo"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 97, "HR": 82, "VideoFilename": "data_10b.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "Vibe S1", "PhoneMake": "Lenovo"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": "Unkown", "HR": 75, "VideoFilename": "data_15.mp4"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 97, "HR": 68, "VideoFilename": "data_01a.mp4"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"PhoneModel": "iPhone SE", "PhoneMake": "Apple"} |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SpO2": 97, "HR": 79, "VideoFilename": "data_01b.mp4"} |
Oops, something went wrong.