Skip to content

Commit 3e81ed3

Browse files
committed
bodypix
1 parent 5730e0d commit 3e81ed3

File tree

6 files changed

+404
-2
lines changed

6 files changed

+404
-2
lines changed

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
"dependencies": {
99
"@headlessui/react": "^1.3.0",
1010
"@heroicons/react": "^1.0.2",
11+
"@tensorflow-models/body-pix": "^2.2.0",
1112
"@tensorflow-models/face-landmarks-detection": "^0.0.3",
13+
"@tensorflow-models/posenet": "^2.2.2",
1214
"@tensorflow/tfjs": "^3.8.0",
1315
"@tensorflow/tfjs-backend-webgl": "^3.10.0",
1416
"@tensorflow/tfjs-converter": "^3.10.0",

src/pages/body-pix/index.js

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
import "@tensorflow/tfjs-backend-webgl";
2+
import * as bodyPix from "@tensorflow-models/body-pix";
3+
import Stats from "stats.js";
4+
import { useEffect, useRef } from "react";
5+
import * as partColorScales from '/src/utils/part-color-scales';
6+
import {drawKeypoints, drawSkeleton} from '/src/utils/draw';
7+
8+
export default function BodyPix() {
9+
let dat, stats
10+
const state = {
11+
video: null,
12+
stream: null,
13+
net: null,
14+
videoConstraints: {},
15+
// Triggers the TensorFlow model to reload
16+
changingArchitecture: false,
17+
changingMultiplier: false,
18+
changingStride: false,
19+
changingResolution: false,
20+
changingQuantBytes: false,
21+
};
22+
const guiState = {
23+
algorithm: "multi-person-instance",
24+
estimate: "partmap",
25+
camera: null,
26+
flipHorizontal: true,
27+
input: {
28+
architecture: "MobileNetV1",
29+
outputStride: 16,
30+
internalResolution: "low",
31+
multiplier: 0.5,
32+
quantBytes: 2,
33+
},
34+
multiPersonDecoding: {
35+
maxDetections: 5,
36+
scoreThreshold: 0.3,
37+
nmsRadius: 20,
38+
numKeypointForMatching: 17,
39+
refineSteps: 10,
40+
},
41+
segmentation: {
42+
segmentationThreshold: 0.7,
43+
effect: "mask",
44+
maskBackground: true,
45+
opacity: 0.7,
46+
backgroundBlurAmount: 3,
47+
maskBlurAmount: 0,
48+
edgeBlurAmount: 3,
49+
},
50+
partMap: {
51+
colorScale: "rainbow",
52+
effect: "partMap",
53+
segmentationThreshold: 0.5,
54+
opacity: 0.9,
55+
blurBodyPartAmount: 3,
56+
bodyPartEdgeBlurAmount: 3,
57+
},
58+
showFps: false,
59+
};
60+
61+
const canvas = useRef()
62+
const videoElement = useRef()
63+
64+
useEffect(async () => {
65+
stats = new Stats();
66+
dat = require('dat.gui')
67+
await loadBodyPix();
68+
await loadVideo(guiState.camera);
69+
70+
let cameras = await getVideoInputs();
71+
72+
// setupFPS();
73+
// setupGui(cameras);
74+
75+
segmentBodyInRealTime();
76+
});
77+
78+
return (
79+
<div className="max-w-5xl">
80+
<video ref={videoElement} playsinline style={{ display: 'none' }}></video>
81+
<canvas ref={canvas}></canvas>
82+
</div>
83+
);
84+
85+
async function getVideoInputs() {
86+
if (!window.navigator.mediaDevices || !window.navigator.mediaDevices.enumerateDevices) {
87+
console.log('enumerateDevices() not supported.');
88+
return [];
89+
}
90+
91+
const devices = await window.navigator.mediaDevices.enumerateDevices();
92+
93+
const videoDevices = devices.filter(device => device.kind === 'videoinput');
94+
95+
return videoDevices;
96+
}
97+
98+
async function setupCamera(cameraLabel) {
99+
if (!window.navigator.mediaDevices || !window.navigator.mediaDevices.getUserMedia) {
100+
throw new Error(
101+
"Browser API window.navigator.mediaDevices.getUserMedia not available"
102+
);
103+
}
104+
105+
// stopExistingVideoCapture();
106+
107+
const videoConstraints = { deviceId: null, facingMode: 'user', width: 400,
108+
height: 400,};
109+
110+
const stream = await window.navigator.mediaDevices.getUserMedia({
111+
audio: false,
112+
video: videoConstraints
113+
});
114+
videoElement.current.srcObject = stream;
115+
116+
return new Promise((resolve) => {
117+
videoElement.current.onloadedmetadata = () => {
118+
videoElement.current.width = videoElement.current.videoWidth;
119+
videoElement.current.height = videoElement.current.videoHeight;
120+
resolve(videoElement.current);
121+
};
122+
});
123+
}
124+
125+
async function loadVideo(cameraLabel) {
126+
await setupCamera(cameraLabel);
127+
videoElement.current.play();
128+
}
129+
130+
async function loadBodyPix() {
131+
state.net = await bodyPix.load({
132+
architecture: guiState.input.architecture,
133+
outputStride: guiState.input.outputStride,
134+
multiplier: guiState.input.multiplier,
135+
quantBytes: guiState.input.quantBytes,
136+
});
137+
}
138+
139+
function drawPoses(personOrPersonPartSegmentation, flipHorizontally, ctx) {
140+
if (Array.isArray(personOrPersonPartSegmentation)) {
141+
personOrPersonPartSegmentation.forEach(personSegmentation => {
142+
let pose = personSegmentation.pose;
143+
if (flipHorizontally) {
144+
pose = bodyPix.flipPoseHorizontal(pose, personSegmentation.width);
145+
}
146+
drawKeypoints(pose.keypoints, 0.1, ctx);
147+
drawSkeleton(pose.keypoints, 0.1, ctx);
148+
});
149+
} else {
150+
personOrPersonPartSegmentation.allPoses.forEach(pose => {
151+
if (flipHorizontally) {
152+
pose = bodyPix.flipPoseHorizontal(
153+
pose, personOrPersonPartSegmentation.width);
154+
}
155+
drawKeypoints(pose.keypoints, 0.1, ctx);
156+
drawSkeleton(pose.keypoints, 0.1, ctx);
157+
})
158+
}
159+
}
160+
161+
function segmentBodyInRealTime() {
162+
async function bodySegmentationFrame() {
163+
// if changing the model or the camera, wait a second for it to complete
164+
// then try again.
165+
if (
166+
state.changingArchitecture ||
167+
state.changingMultiplier ||
168+
state.changingCamera ||
169+
state.changingStride ||
170+
state.changingQuantBytes
171+
) {
172+
console.log("load model...");
173+
loadBodyPix();
174+
state.changingArchitecture = false;
175+
state.changingMultiplier = false;
176+
state.changingStride = false;
177+
state.changingQuantBytes = false;
178+
}
179+
180+
// Begin monitoring code for frames per second
181+
stats.begin();
182+
183+
const flipHorizontally = guiState.flipHorizontal;
184+
185+
if (!canvas.current) return
186+
if (!videoElement.current) return
187+
switch (guiState.estimate) {
188+
case "segmentation":
189+
const multiPersonSegmentation = await estimateSegmentation();
190+
switch (guiState.segmentation.effect) {
191+
case "mask":
192+
const ctx = canvas.current.getContext("2d");
193+
const foregroundColor = { r: 255, g: 255, b: 255, a: 255 };
194+
const backgroundColor = { r: 0, g: 0, b: 0, a: 255 };
195+
const mask = bodyPix.toMask(
196+
multiPersonSegmentation,
197+
foregroundColor,
198+
backgroundColor,
199+
true
200+
);
201+
202+
bodyPix.drawMask(
203+
canvas.current,
204+
videoElement.current,
205+
mask,
206+
guiState.segmentation.opacity,
207+
guiState.segmentation.maskBlurAmount,
208+
flipHorizontally
209+
);
210+
drawPoses(multiPersonSegmentation, flipHorizontally, ctx);
211+
break;
212+
case "bokeh":
213+
bodyPix.drawBokehEffect(
214+
canvas.current,
215+
videoElement.current,
216+
multiPersonSegmentation,
217+
+guiState.segmentation.backgroundBlurAmount,
218+
guiState.segmentation.edgeBlurAmount,
219+
flipHorizontally
220+
);
221+
break;
222+
}
223+
224+
break;
225+
case "partmap":
226+
const ctx = canvas.current.getContext("2d");
227+
const multiPersonPartSegmentation = await estimatePartSegmentation();
228+
const coloredPartImageData = bodyPix.toColoredPartMask(
229+
multiPersonPartSegmentation,
230+
partColorScales[guiState.partMap.colorScale]
231+
);
232+
233+
const maskBlurAmount = 0;
234+
switch (guiState.partMap.effect) {
235+
case "pixelation":
236+
const pixelCellWidth = 10.0;
237+
238+
bodyPix.drawPixelatedMask(
239+
canvas.current,
240+
videoElement.current,
241+
coloredPartImageData,
242+
guiState.partMap.opacity,
243+
maskBlurAmount,
244+
flipHorizontally,
245+
pixelCellWidth
246+
);
247+
break;
248+
case "partMap":
249+
if(!videoElement.current) return
250+
bodyPix.drawMask(
251+
canvas.current,
252+
videoElement.current,
253+
coloredPartImageData,
254+
guiState.opacity,
255+
maskBlurAmount,
256+
flipHorizontally
257+
);
258+
break;
259+
case "blurBodyPart":
260+
const blurBodyPartIds = [0, 1];
261+
bodyPix.blurBodyPart(
262+
canvas.current,
263+
videoElement.current,
264+
multiPersonPartSegmentation,
265+
blurBodyPartIds,
266+
guiState.partMap.blurBodyPartAmount,
267+
guiState.partMap.edgeBlurAmount,
268+
flipHorizontally
269+
);
270+
}
271+
drawPoses(multiPersonPartSegmentation, flipHorizontally, ctx);
272+
break;
273+
default:
274+
break;
275+
}
276+
277+
// End monitoring code for frames per second
278+
stats.end();
279+
280+
requestAnimationFrame(bodySegmentationFrame);
281+
}
282+
283+
bodySegmentationFrame();
284+
}
285+
286+
async function estimatePartSegmentation() {
287+
switch (guiState.algorithm) {
288+
case 'multi-person-instance':
289+
return await state.net.segmentMultiPersonParts(videoElement.current, {
290+
internalResolution: guiState.input.internalResolution,
291+
segmentationThreshold: guiState.segmentation.segmentationThreshold,
292+
maxDetections: guiState.multiPersonDecoding.maxDetections,
293+
scoreThreshold: guiState.multiPersonDecoding.scoreThreshold,
294+
nmsRadius: guiState.multiPersonDecoding.nmsRadius,
295+
numKeypointForMatching:
296+
guiState.multiPersonDecoding.numKeypointForMatching,
297+
refineSteps: guiState.multiPersonDecoding.refineSteps
298+
});
299+
case 'person':
300+
return await state.net.segmentPersonParts(videoElement.current, {
301+
internalResolution: guiState.input.internalResolution,
302+
segmentationThreshold: guiState.segmentation.segmentationThreshold,
303+
maxDetections: guiState.multiPersonDecoding.maxDetections,
304+
scoreThreshold: guiState.multiPersonDecoding.scoreThreshold,
305+
nmsRadius: guiState.multiPersonDecoding.nmsRadius,
306+
});
307+
default:
308+
break;
309+
};
310+
return multiPersonPartSegmentation;
311+
}
312+
}

src/pages/index.js

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ export default function Home() {
55
return (
66
<div className="flex flex-col justify-center items-center h-48 w-full gap-y-8">
77
<Link href="/style-transfer" >
8-
<a className="p-4 text-lg bg-primary rounded text-white shadow-md w-7/12 text-center">Get Started With Style Transfer</a>
8+
<a className="p-4 text-lg bg-primary rounded text-white shadow-md w-7/12 text-center">Style Transfer</a>
99
</Link>
1010
<Link href="/face-landmarks-detection" >
11-
<a className="p-4 text-lg bg-secondary rounded text-white shadow-md w-7/12 text-center">Get Started With <span className="font-bold">Face Landmarks Detection</span></a>
11+
<a className="p-4 text-lg bg-secondary rounded text-white shadow-md w-7/12 text-center">Face Landmarks Detection</a>
12+
</Link>
13+
<Link href="/body-pix" >
14+
<a className="p-4 text-lg bg-secondary rounded text-white shadow-md w-7/12 text-center">Body Pix</a>
1215
</Link>
1316
</div>
1417
)

src/utils/draw.js

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import * as posenet from '@tensorflow-models/posenet';
2+
3+
// Styling shared by the drawing helpers below.
const COLOR = 'aqua';              // keypoint dot / skeleton segment colour
const BOUNDING_BOX_COLOR = 'red';  // not referenced in this file — presumably kept for a bounding-box helper; verify before removing
const LINE_WIDTH = 2;              // stroke width for skeleton segments
6+
7+
/**
 * Paint one filled circle on a 2D canvas context.
 *
 * @param {CanvasRenderingContext2D} ctx - target drawing context
 * @param {number} y - vertical centre of the circle
 * @param {number} x - horizontal centre of the circle
 * @param {number} r - radius in pixels
 * @param {string} color - CSS colour used as the fill style
 */
export function drawPoint(ctx, y, x, r, color) {
  const FULL_TURN = 2 * Math.PI;
  ctx.beginPath();
  ctx.arc(x, y, r, 0, FULL_TURN);
  ctx.fillStyle = color;
  ctx.fill();
}
13+
14+
/**
 * Render every sufficiently-confident keypoint as a small dot.
 *
 * @param {Array<{score: number, position: {x: number, y: number}}>} keypoints
 * @param {number} minConfidence - keypoints scoring below this are skipped
 * @param {CanvasRenderingContext2D} ctx - target drawing context
 * @param {number} [scale=1] - multiplier applied to keypoint coordinates
 */
export function drawKeypoints(keypoints, minConfidence, ctx, scale = 1) {
  for (const keypoint of keypoints) {
    // Ignore detections the model is not confident about.
    if (keypoint.score < minConfidence) {
      continue;
    }

    const { y, x } = keypoint.position;
    drawPoint(ctx, y * scale, x * scale, 3, COLOR);
  }
}
26+
27+
/**
 * Connect adjacent keypoints with line segments to form a stick-figure
 * skeleton. Which keypoint pairs are "adjacent" is decided by posenet.
 *
 * @param {Array<{score: number, position: {x: number, y: number}}>} keypoints
 * @param {number} minConfidence - forwarded to posenet's adjacency filter
 * @param {CanvasRenderingContext2D} ctx - target drawing context
 * @param {number} [scale=1] - multiplier applied to segment coordinates
 */
export function drawSkeleton(keypoints, minConfidence, ctx, scale = 1) {
  const adjacentKeyPoints =
      posenet.getAdjacentKeyPoints(keypoints, minConfidence);

  // drawSegment expects [y, x] tuples rather than {x, y} objects.
  const toTuple = ({ y, x }) => [y, x];

  for (const [from, to] of adjacentKeyPoints) {
    drawSegment(toTuple(from.position), toTuple(to.position), COLOR, scale, ctx);
  }
}
41+
42+
/**
 * Stroke a single line segment between two [y, x] points.
 *
 * @param {[number, number]} start - segment start as [ay, ax]
 * @param {[number, number]} end - segment end as [by, bx]
 * @param {string} color - CSS stroke colour
 * @param {number} scale - multiplier applied to every coordinate
 * @param {CanvasRenderingContext2D} ctx - target drawing context
 */
export function drawSegment([ay, ax], [by, bx], color, scale, ctx) {
  ctx.beginPath();
  ctx.moveTo(ax * scale, ay * scale);
  ctx.lineTo(bx * scale, by * scale);
  ctx.lineWidth = LINE_WIDTH;
  ctx.strokeStyle = color;
  ctx.stroke();
}

0 commit comments

Comments
 (0)