Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save faces image #415

Merged
merged 3 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion face_recognition/face_cluster_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def update_face_cluster(self):

if last_face_cluster_time and int(mtime) <= int(last_face_cluster_time.timestamp()):
continue
self._face_recognition_manager.face_cluster(repo_id)

try:
self._face_recognition_manager.face_cluster(repo_id)
except Exception as e:
logger.error("repo: %s, update face cluster error: %s" % (repo_id, e))

logger.info("Finish update face cluster")
33 changes: 25 additions & 8 deletions face_recognition/face_recognition_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
from datetime import datetime
import numpy as np
Expand All @@ -11,7 +10,7 @@
from seafevents.repo_metadata.constants import METADATA_OP_LIMIT
from seafevents.face_recognition.db import update_face_cluster_time, update_face_cluster_time
from seafevents.face_recognition.utils import get_faces_rows, get_cluster_by_center, b64encode_embeddings, \
b64decode_embeddings, get_faces_rows, get_face_embeddings
b64decode_embeddings, get_faces_rows, get_face_embeddings, get_image_face, save_face

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +55,7 @@ def init_face_recognition(self, repo_id):

for item in result:
obj_id = item['obj_id']
face_embeddings = item['embeddings']
face_embeddings = [face['embedding'] for face in item['faces']]
for row in obj_id_to_rows.get(obj_id, []):
row_id = row[METADATA_TABLE.columns.id.name]
updated_rows.append({
Expand All @@ -82,7 +81,7 @@ def face_cluster(self, repo_id):
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
update_face_cluster_time(self._db_session_class, repo_id, current_time)

sql = f'SELECT `{METADATA_TABLE.columns.id.name}`, `{METADATA_TABLE.columns.face_vectors.name}` FROM `{METADATA_TABLE.name}` WHERE `{METADATA_TABLE.columns.face_vectors.name}` IS NOT NULL'
sql = f'SELECT `{METADATA_TABLE.columns.id.name}`, `{METADATA_TABLE.columns.face_vectors.name}`, `{METADATA_TABLE.columns.obj_id.name}` FROM `{METADATA_TABLE.name}` WHERE `{METADATA_TABLE.columns.face_vectors.name}` IS NOT NULL'
query_result = query_metadata_rows(repo_id, self.metadata_server_api, sql)
if not query_result:
return
Expand All @@ -95,7 +94,9 @@ def face_cluster(self, repo_id):

vectors = []
row_ids = []
id_to_record = dict()
for item in query_result:
id_to_record[item[METADATA_TABLE.columns.id.name]] = item
row_id = item[METADATA_TABLE.columns.id.name]
face_vectors = b64decode_embeddings(item[METADATA_TABLE.columns.face_vectors.name])
for face_vector in face_vectors:
Expand Down Expand Up @@ -131,11 +132,27 @@ def face_cluster(self, repo_id):
continue

result = self.metadata_server_api.insert_rows(repo_id, faces_table_id, [face_row])
row_id = result.get('row_ids')[0]
face_row_id = result.get('row_ids')[0]
row_id_map = {
row_id: related_row_ids
face_row_id: related_row_ids
}
self.metadata_server_api.insert_link(repo_id, FACES_TABLE.link_id, faces_table_id, row_id_map)

need_delete_row_ids = [item[FACES_TABLE.columns.id.name] for item in old_cluster]
self.metadata_server_api.delete_rows(repo_id, faces_table_id, need_delete_row_ids)
face_image = None
for row_id in related_row_ids:
if row_ids.count(row_id) == 1:
record = id_to_record[row_id]
obj_id = record[METADATA_TABLE.columns.obj_id.name]
face_image = get_image_face(repo_id, obj_id, self.image_embedding_api, cluster_center.tolist())
break

if not face_image:
record = id_to_record[related_row_ids[0]]
obj_id = record[METADATA_TABLE.columns.obj_id.name]
face_image = get_image_face(repo_id, obj_id, self.image_embedding_api, cluster_center.tolist())

if not face_image:
continue

filename = f'{face_row_id}.jpg'
save_face(repo_id, face_image, filename)
49 changes: 47 additions & 2 deletions face_recognition/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import base64
import json
import io
import os
import posixpath

import numpy as np
from PIL import Image

from seaserv import seafile_api

from seafevents.repo_metadata.utils import FACES_TABLE, query_metadata_rows
from seafevents.repo_metadata.utils import FACES_TABLE, query_metadata_rows, get_file_content
from seafevents.repo_metadata.constants import FACE_EMBEDDING_DIM

FACES_TMP_DIR = '/tmp'
FACES_SAVE_PATH = '_Internal/Faces'


def feature_distance(feature1, feature2):
diff = np.subtract(feature1, feature2)
Expand Down Expand Up @@ -57,3 +66,39 @@ def get_face_embeddings(repo_id, image_embedding_api, obj_ids):
embeddings.append(query_results)

return embeddings


def get_image_face(repo_id, obj_id, image_embedding_api, center):
    """Return the cropped face of one image that best matches a cluster center.

    Asks ``image_embedding_api.face_embeddings`` for the faces detected in the
    file ``obj_id`` and crops the chosen one via ``get_face_by_box``.

    :param repo_id: id of the repo that holds the image.
    :param obj_id: object id of the image file.
    :param image_embedding_api: client exposing ``face_embeddings(repo_id, obj_ids)``,
        whose response carries a ``data`` list with one entry per object.
    :param center: embedding of the cluster center the face should be closest to.
    :return: whatever ``get_face_by_box`` returns for the selected face, or
        ``None`` when the service reports no data or no faces for the image.
    """
    result = image_embedding_api.face_embeddings(repo_id, [obj_id]).get('data', [])
    if not result:
        return None

    # Only one obj_id is requested, so only the first entry is relevant.
    # The number of *faces* (not the number of entries) decides whether a
    # distance comparison is needed.
    faces = result[0].get('faces') or []
    if not faces:
        return None

    if len(faces) == 1:
        return get_face_by_box(repo_id, obj_id, faces[0]['box'])

    # Pick the face whose embedding has the smallest distance to the center;
    # the distance itself is a float and must not be used as a list index.
    closest_face = min(faces, key=lambda face: feature_distance(center, face['embedding']))
    return get_face_by_box(repo_id, obj_id, closest_face['box'])


def get_face_by_box(repo_id, obj_id, box):
    """Crop a region out of an image file and return it as JPEG bytes.

    :param repo_id: id of the repo that holds the image.
    :param obj_id: object id of the image file; its bytes are fetched with
        ``get_file_content``.
    :param box: sequence whose first four items are passed to ``Image.crop``
        (presumably left, upper, right, lower pixel coordinates — confirm
        against the face-embedding service's response format).
    :return: the JPEG-encoded crop as ``bytes``, or ``None`` when the file
        has no content.
    """
    content = get_file_content(repo_id, obj_id)
    if not content:
        return None

    # The context manager releases the decoder resources held by the image.
    with Image.open(io.BytesIO(content)) as image:
        cropped_image = image.crop((box[0], box[1], box[2], box[3]))

        # JPEG cannot store alpha or palette modes (e.g. RGBA / P from PNG or
        # GIF sources); convert those to RGB so that save() does not raise.
        if cropped_image.mode not in ('RGB', 'L'):
            cropped_image = cropped_image.convert('RGB')

        output_buffer = io.BytesIO()
        cropped_image.save(output_buffer, format='jpeg')

    # getvalue() returns the whole buffer regardless of the stream position.
    return output_buffer.getvalue()


def save_face(repo_id, image, filename):
    """Upload face image bytes into the repo's faces folder.

    The bytes are first written to ``FACES_TMP_DIR/<filename>`` and then
    posted to ``FACES_SAVE_PATH`` in the repo through ``seafile_api.post_file``.

    :param repo_id: id of the repo that receives the file.
    :param image: raw image bytes to store.
    :param filename: name used both for the temporary file and for the
        uploaded file.
    """
    tmp_content_path = posixpath.join(FACES_TMP_DIR, filename)
    try:
        with open(tmp_content_path, 'wb') as f:
            f.write(image)

        seafile_api.post_file(repo_id, tmp_content_path, FACES_SAVE_PATH, filename, 'system')
    finally:
        # Always clean up the temporary file, even when writing or uploading
        # fails, so failed runs do not leave files behind in the tmp dir.
        if os.path.exists(tmp_content_path):
            os.remove(tmp_content_path)
2 changes: 1 addition & 1 deletion repo_metadata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def add_file_details(repo_id, obj_ids, metadata_server_api, image_embedding_api=
if image_embedding_api and not row.get(METADATA_TABLE.columns.face_vectors.name):
result = image_embedding_api.face_embeddings(repo_id, [obj_id]).get('data', [])
if result:
face_embeddings = result[0]['embeddings']
face_embeddings = [face['embedding'] for face in result[0]['faces']]
update_row[METADATA_TABLE.columns.face_vectors.name] = b64encode_embeddings(face_embeddings)
elif file_type == '_video':
update_row = add_video_detail_row(row_id, content, has_capture_time_column)
Expand Down
Loading