Skip to content

Commit

Permalink
Support loading MobileNet from distilled parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Aug 28, 2019
1 parent 246adaa commit 9aace3c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cnn/imagenet_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@
metavar='NSL', help='Number of structured layer (default 7)')
parser.add_argument('--width', default=1.0, type=float,
metavar='WIDTH', help='Width multiplier of the CNN (default 1.0)')
parser.add_argument('--distilled-param-path', default='', type=str, metavar='PATH',
help='path to distilled parameters (default: none)')
parser.add_argument('--full-model-path', default='', type=str, metavar='PATH',
help='path to full model checkpoint (default: none)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
Expand Down Expand Up @@ -139,6 +143,8 @@ def main():
elif args.arch == 'mobilenetv1_struct':
model = MobileNet(width_mult=args.width, structure=[args.struct] * args.n_struct_layers,
softmax_structure=args.softmax_struct)
if args.distilled_param_path:
model.load_state_dict(model.mixed_model_state_dict(args.full_model_path, args.distilled_param_path))
else:
model = models.__dict__[args.arch]()
if args.local_rank == 0:
Expand Down
15 changes: 15 additions & 0 deletions cnn/mobilenet_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ def forward(self, x):
out = self.linear(out)
return out

def mixed_model_state_dict(self, full_model_path, distilled_param_path):
current_state_dict_keys = self.state_dict().keys()
full_model_state_dict = torch.load(full_model_path, map_location='cpu')['state_dict']
full_model_state_dict = {name.replace('module.', ''): param for name, param in full_model_state_dict.items()}
distilled_params = torch.load(distilled_param_path, map_location='cpu')
state_dict = {name: param for name, param in full_model_state_dict.items() if name in current_state_dict_keys}
for i, struct in enumerate(self.structure):
# Only support butterfly for now
if struct.startswith('odo') or struct.startswith('regular'):
layer = f'layers.{i}.conv2'
nblocks = int(struct.split('_')[1])
structured_param = distilled_params[layer, nblocks]
state_dict.update({layer + '.' + name: param for name, param in structured_param.items()})
return state_dict


def test():
net = MobileNet()
Expand Down

0 comments on commit 9aace3c

Please sign in to comment.