Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated to tensorflow 2- #7

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 33 additions & 47 deletions convert_weights.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
"""Convert weights from kinetics and imagenet pretrained model
rebuild_ckpoint_kinetics: reconstruct an I3D model based on the pure
tensorflow framework, where weights are initialized from the kinetics-i3d model
"""

import os
import sys
import numpy as np
import tensorflow as tf

Expand All @@ -28,51 +22,43 @@
}

def rebuild_ckpoint_kinetics(checkpoint_dir, save_path):
"""rebuild the checkpoint from kinetics-i3d model
Inception-v1 inflated 3D ConvNet
"""
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
with tf.Session() as sess:
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
raw_var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
for k, v in KINETICS_NAME_MAP.items():
var_name = var_name.replace(k, v)
print(var_name, raw_var.shape)
var = tf.Variable(raw_var, name=var_name)
var_list = {} # To store the variables
for var_name, _ in tf.train.list_variables(checkpoint_dir):
raw_var = tf.train.load_variable(checkpoint_dir, var_name)
for k, v in KINETICS_NAME_MAP.items():
var_name = var_name.replace(k, v)
print(var_name, raw_var.shape)
var_list[var_name] = tf.Variable(raw_var, name=var_name)

saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.save(sess, save_path)
# Save new variables
ckpt = tf.train.Checkpoint(**var_list)
ckpt.save(save_path)

def rebuild_ckpoint_imagenet(checkpoint_dir, save_path):
"""rebuild the checkpoint from imagenet 2d model
Inception-v2 inflated 3d ConvNet
"""
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
var_list = {} # To store the variables
fg = True
with tf.Session() as sess:
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
raw_var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
if var_name.startswith('InceptionV2/Conv2d_1a_7x7'):
if fg:
var_name = 'v/SenseTime_I3D_V2/Conv3d_1a_7x7x7/kernel'
raw_var = np.random.normal(0.0, 1.0, (7, 7, 7, 3, 64)) / 7.0
fg = False
else:
continue
elif var_name.find('weights') > -1:
kernel = raw_var.shape[0]
res = [raw_var for i in range(kernel)]
raw_var = np.stack(res, axis=0)
raw_var = raw_var / (kernel * 1.0)
for k, v in IMAGENET_NAME_MAP.items():
var_name = var_name.replace(k, v)
print(var_name, raw_var.shape)
var = tf.Variable(raw_var, name=var_name)
for var_name, _ in tf.train.list_variables(checkpoint_dir):
raw_var = tf.train.load_variable(checkpoint_dir, var_name)
if var_name.startswith('InceptionV2/Conv2d_1a_7x7'):
if fg:
var_name = 'v/SenseTime_I3D_V2/Conv3d_1a_7x7x7/kernel'
raw_var = np.random.normal(0.0, 1.0, (7, 7, 7, 3, 64)) / 7.0
fg = False
else:
continue
elif var_name.find('weights') > -1:
kernel = raw_var.shape[0]
res = [raw_var for i in range(kernel)]
raw_var = np.stack(res, axis=0)
raw_var = raw_var / (kernel * 1.0)
for k, v in IMAGENET_NAME_MAP.items():
var_name = var_name.replace(k, v)
print(var_name, raw_var.shape)
var_list[var_name] = tf.Variable(raw_var, name=var_name)

saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.save(sess, save_path)
# Save new variables
ckpt = tf.train.Checkpoint(**var_list)
ckpt.save(save_path)

def main():
checkpoint_dir = './kinetics-i3d/data/checkpoints/rgb_imagenet/model.ckpt'
Expand All @@ -81,4 +67,4 @@ def main():
# rebuild_ckpoint_imagenet(checkpoint_dir, './kinetics-i3d/data/inceptionv2_i3d/model')

if __name__ == '__main__':
main()
main()