Skip to content

Commit

Permalink
Move ImageNet metadata (aka info) files to timm/data/_info. Add helpe…
Browse files Browse the repository at this point in the history
…r classes to make info available for labelling. Update inference.py for first use.
  • Loading branch information
rwightman committed Feb 7, 2023
1 parent 89b0452 commit 0f2803d
Show file tree
Hide file tree
Showing 22 changed files with 65,544 additions and 15 deletions.
5 changes: 3 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include timm/models/pruned/*.txt

include timm/models/_pruned/*.txt
include timm/data/_info/*.txt
include timm/data/_info/*.json
56 changes: 47 additions & 9 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pandas as pd
import torch

from timm.data import create_dataset, create_loader, resolve_data_config
from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
Expand Down Expand Up @@ -46,6 +46,7 @@

_FMT_EXT = {
'json': '.json',
'json-record': '.json',
'json-split': '.json',
'parquet': '.parquet',
'csv': '.csv',
Expand Down Expand Up @@ -122,7 +123,7 @@
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.")

parser.add_argument('--results-dir',type=str, default=None,
parser.add_argument('--results-dir', type=str, default=None,
help='folder for output results')
parser.add_argument('--results-file', type=str, default=None,
help='results filename (relative to results-dir)')
Expand All @@ -134,14 +135,20 @@
metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False,
help='use full sample name in output (not just basename).')
parser.add_argument('--filename-col', default='filename',
parser.add_argument('--filename-col', type=str, default='filename',
help='name for filename / sample name column')
parser.add_argument('--index-col', default='index',
parser.add_argument('--index-col', type=str, default='index',
help='name for output indices column(s)')
parser.add_argument('--output-col', default=None,
parser.add_argument('--label-col', type=str, default='label',
help='name for output indices column(s)')
parser.add_argument('--output-col', type=str, default=None,
help='name for logit/probs output column(s)')
parser.add_argument('--output-type', default='prob',
parser.add_argument('--output-type', type=str, default='prob',
help='output type colum ("prob" for probabilities, "logit" for raw logits)')
parser.add_argument('--label-type', type=str, default='description',
help='type of label to output, one of "none", "name", "description", "detailed"')
parser.add_argument('--include-index', action='store_true', default=False,
help='include the class index in results')
parser.add_argument('--exclude-output', action='store_true', default=False,
help='exclude logits/probs from results, just indices. topk must be set !=0.')

Expand Down Expand Up @@ -237,10 +244,26 @@ def main():
**data_config,
)

to_label = None
if args.label_type in ('name', 'description', 'detail'):
imagenet_subset = infer_imagenet_subset(model)
if imagenet_subset is not None:
dataset_info = ImageNetInfo(imagenet_subset)
if args.label_type == 'name':
to_label = lambda x: dataset_info.index_to_label_name(x)
elif args.label_type == 'detail':
to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
else:
to_label = lambda x: dataset_info.index_to_description(x)
to_label = np.vectorize(to_label)
else:
_logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")

top_k = min(args.topk, args.num_classes)
batch_time = AverageMeter()
end = time.time()
all_indices = []
all_labels = []
all_outputs = []
use_probs = args.output_type == 'prob'
with torch.no_grad():
Expand All @@ -254,7 +277,12 @@ def main():

if top_k:
output, indices = output.topk(top_k)
all_indices.append(indices.cpu().numpy())
np_indices = indices.cpu().numpy()
if args.include_index:
all_indices.append(np_indices)
if to_label is not None:
np_labels = to_label(np_indices)
all_labels.append(np_labels)

all_outputs.append(output.cpu().numpy())

Expand All @@ -267,6 +295,7 @@ def main():
batch_idx, len(loader), batch_time=batch_time))

all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
filenames = loader.dataset.filenames(basename=not args.fullname)

Expand All @@ -276,13 +305,20 @@ def main():
if all_indices is not None:
for i in range(all_indices.shape[-1]):
data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
if all_labels is not None:
for i in range(all_labels.shape[-1]):
data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
for i in range(all_outputs.shape[-1]):
data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
else:
if all_indices is not None:
if all_indices.shape[-1] == 1:
all_indices = all_indices.squeeze(-1)
data_dict[args.index_col] = list(all_indices)
if all_labels is not None:
if all_labels.shape[-1] == 1:
all_labels = all_labels.squeeze(-1)
data_dict[args.label_col] = list(all_labels)
if all_outputs.shape[-1] == 1:
all_outputs = all_outputs.squeeze(-1)
data_dict[output_col] = list(all_outputs)
Expand All @@ -291,7 +327,7 @@ def main():

results_filename = args.results_file
if results_filename:
filename_no_ext, ext = os.path.splitext(results_filename)[-1]
filename_no_ext, ext = os.path.splitext(results_filename)
if ext and ext in _FMT_EXT.values():
# if filename provided with one of expected ext,
# remove it as it will be added back
Expand All @@ -308,14 +344,16 @@ def main():
save_results(df, results_filename, fmt)

print(f'--result')
print(json.dumps(dict(filename=results_filename)))
print(df.set_index(args.filename_col).to_json(orient='index', indent=4))


def save_results(df, results_filename, results_format='csv', filename_col='filename'):
results_filename += _FMT_EXT[results_format]
if results_format == 'parquet':
df.set_index(filename_col).to_parquet(results_filename)
elif results_format == 'json':
df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
elif results_format == 'json-records':
df.to_json(results_filename, lines=True, orient='records')
elif results_format == 'json-split':
df.to_json(results_filename, indent=4, orient='split', index=False)
Expand Down
2 changes: 2 additions & 0 deletions timm/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .dataset_info import DatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .readers import create_reader
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 0f2803d

Please sign in to comment.