博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
ResNet 修改
阅读量:6296 次
发布时间:2019-06-22

本文共 17617 字,大约阅读时间需要 58 分钟。

apache 开源项目

修改如下:

训练模块

"""Train a ResNet-v2 classifier with MXNet on CIFAR-10 or ImageNet-style RecordIO data."""
import argparse
import logging
import os

import mxnet as mx

from symbol_resnet import resnet

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Residual units per stage for the four-stage ("imagenet") network, by depth.
_IMAGENET_UNITS = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
    269: [3, 30, 48, 8],
}


def multi_factor_scheduler(begin_epoch, epoch_size, step=(60, 75, 90), factor=0.1):
    """Build a step-wise learning-rate scheduler for a (possibly resumed) run.

    Args:
        begin_epoch: Epoch the run starts from; decay epochs at or before it
            are dropped.
        epoch_size: Number of batches per epoch, used to convert epochs into
            update counts.
        step: Epochs at which the learning rate is multiplied by ``factor``.
            A tuple is used as the default so it is never shared mutable state.
        factor: Multiplicative decay applied at every step.

    Returns:
        An ``mx.lr_scheduler.MultiFactorScheduler``, or ``None`` when every
        decay epoch already lies in the past.
    """
    remaining = [epoch_size * (x - begin_epoch) for x in step if x - begin_epoch > 0]
    if not remaining:
        return None
    return mx.lr_scheduler.MultiFactorScheduler(step=remaining, factor=factor)


def _build_symbol(args):
    """Return the ResNet symbol for ``args.data_type`` / ``args.depth``.

    Mutates ``args`` exactly like the original script: CIFAR-10 forces
    ``aug_level = 1`` and ``num_classes = 10``; the "imagenet" branch forces
    ``num_classes = 3``.

    Raises:
        ValueError: For an unsupported depth or data type.
    """
    if args.data_type == "cifar10":
        args.aug_level = 1
        args.num_classes = 10
        # depth should be one of 110, 164, 1001, ... i.e. (depth - 2) % 9 == 0
        # for the bottleneck variant, (depth - 2) % 6 == 0 for the basic one.
        if (args.depth - 2) % 9 == 0 and args.depth >= 164:
            # Floor division: "/" yields a float under Python 3, and the unit
            # count must stay an int (presumably used as a loop bound inside
            # resnet() -- confirm against symbol_resnet).
            per_unit = [(args.depth - 2) // 9]
            filter_list = [16, 64, 128, 256]
            bottle_neck = True
        elif (args.depth - 2) % 6 == 0 and args.depth < 164:
            per_unit = [(args.depth - 2) // 6]
            filter_list = [16, 16, 32, 64]
            bottle_neck = False
        else:
            raise ValueError("no experiments done on detph {}, you can do it youself".format(args.depth))
        return resnet(units=per_unit * 3, num_stage=3, filter_list=filter_list,
                      num_class=args.num_classes, data_type="cifar10",
                      bottle_neck=bottle_neck, bn_mom=args.bn_mom,
                      workspace=args.workspace, memonger=args.memonger)

    if args.data_type == "imagenet":
        # NOTE(review): this deliberately overrides --num-classes; the author
        # trains a 3-class task on the imagenet-style pipeline. Change it here
        # to match your own number of classes.
        args.num_classes = 3
        if args.depth not in _IMAGENET_UNITS:
            raise ValueError("no experiments done on detph {}, you can do it youself".format(args.depth))
        deep = args.depth >= 50  # depth >= 50 selects bottleneck blocks and wider filters
        return resnet(units=list(_IMAGENET_UNITS[args.depth]), num_stage=4,
                      filter_list=[64, 256, 512, 1024, 2048] if deep else [64, 64, 128, 256, 512],
                      num_class=args.num_classes, data_type="imagenet",
                      bottle_neck=deep, bn_mom=args.bn_mom,
                      workspace=args.workspace, memonger=args.memonger)

    raise ValueError("do not support {} yet".format(args.data_type))


def _make_iterators(args, kv):
    """Create the (train, val) ``ImageRecordIter`` pair, sharded over kvstore workers."""
    is_cifar = args.data_type == "cifar10"
    # Level 1 (always the case for CIFAR-10) means crop + mirror only.
    light_aug = is_cifar or args.aug_level == 1
    data_shape = (3, 32, 32) if is_cifar else (3, 128, 128)

    if is_cifar:
        train_rec = os.path.join(args.data_dir, "cifar10_train.rec")
        val_rec = os.path.join(args.data_dir, "cifar10_val.rec")
    else:
        train_name = "train_256_q90.rec" if args.aug_level == 1 else "train_480_q90.rec"
        train_rec = os.path.join(args.data_dir, train_name)
        val_rec = os.path.join(args.data_dir, "val_256_q90.rec")

    train = mx.io.ImageRecordIter(
        path_imgrec=train_rec,
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        data_shape=data_shape,
        batch_size=args.batch_size,
        pad=4 if is_cifar else 0,
        fill_value=127,  # only used when pad is valid
        rand_crop=True,
        max_random_scale=1.0,  # 480 with imagenet, 32 with cifar10
        min_random_scale=1.0 if light_aug else 0.533,  # 256.0/480.0
        max_aspect_ratio=0 if light_aug else 0.25,
        random_h=0 if light_aug else 36,  # 0.4*90
        random_s=0 if light_aug else 50,  # 0.4*127
        random_l=0 if light_aug else 50,  # 0.4*127
        max_rotate_angle=0 if args.aug_level <= 2 else 10,
        max_shear_ratio=0 if args.aug_level <= 2 else 0.1,
        rand_mirror=True,
        shuffle=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    val = mx.io.ImageRecordIter(
        path_imgrec=val_rec,
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        batch_size=args.batch_size,
        data_shape=data_shape,
        rand_crop=False,
        rand_mirror=False,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    return train, val


def main():
    """Build the network and data iterators from the global ``args`` and train.

    Checkpoints are written after every epoch under
    ``model/resnet-<data_type>-<depth>-<rank>``.
    """
    symbol = _build_symbol(args)
    is_cifar = args.data_type == "cifar10"

    kv = mx.kvstore.create(args.kv_store)
    devs = mx.cpu() if args.gpus is None else [mx.gpu(int(i)) for i in args.gpus.split(',')]
    epoch_size = max(int(args.num_examples / args.batch_size / kv.num_workers), 1)
    begin_epoch = args.model_load_epoch or 0

    # Several workers may start at once; tolerate the directory appearing
    # between an existence check and the creation call.
    try:
        os.makedirs("./model")
    except OSError:
        if not os.path.isdir("./model"):
            raise
    model_prefix = "model/resnet-{}-{}-{}".format(args.data_type, args.depth, kv.rank)
    checkpoint = mx.callback.do_checkpoint(model_prefix)

    arg_params = None
    aux_params = None
    if args.retrain:
        _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.model_load_epoch)

    if args.memonger:
        import memonger
        plan_shape = (args.batch_size, 3, 32, 32) if is_cifar else (args.batch_size, 3, 128, 128)
        symbol = memonger.search_plan(symbol, data=plan_shape)

    train, val = _make_iterators(args, kv)

    decay_epochs = [120, 160] if is_cifar else [30, 60, 90]
    model = mx.model.FeedForward(
        ctx=devs,
        symbol=symbol,
        arg_params=arg_params,
        aux_params=aux_params,
        num_epoch=200 if is_cifar else 120,
        begin_epoch=begin_epoch,
        learning_rate=args.lr,
        momentum=args.mom,
        wd=args.wd,
        optimizer='nag',
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        lr_scheduler=multi_factor_scheduler(begin_epoch, epoch_size, step=decay_epochs, factor=0.1),
    )

    eval_metric = ['acc', 'ce']
    if not is_cifar:
        eval_metric.append(mx.metric.create('top_k_accuracy', top_k=5))
    model.fit(
        X=train,
        eval_data=val,
        eval_metric=eval_metric,
        kvstore=kv,
        batch_end_callback=mx.callback.Speedometer(args.batch_size, args.frequent),
        epoch_end_callback=checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="command for training resnet-v2")
    parser.add_argument('--gpus', type=str, default='0', help='the gpus will be used, e.g "0,1,2,3"')
    parser.add_argument('--data-dir', type=str, default='./data/imagenet/', help='the input data directory')
    parser.add_argument('--data-type', type=str, default='imagenet', help='the dataset type')
    parser.add_argument('--list-dir', type=str, default='./',
                        help='the directory which contain the training list file')
    parser.add_argument('--lr', type=float, default=0.1, help='initialization learning reate')
    parser.add_argument('--mom', type=float, default=0.9, help='momentum for sgd')
    parser.add_argument('--bn-mom', type=float, default=0.9, help='momentum for batch normlization')
    parser.add_argument('--wd', type=float, default=0.0001, help='weight decay for sgd')
    parser.add_argument('--batch-size', type=int, default=256, help='the batch size')
    parser.add_argument('--workspace', type=int, default=512, help='memory space size(MB) used in convolution, if xpu '
                        ' memory is oom, then you can try smaller vale, such as --workspace 256')
    parser.add_argument('--depth', type=int, default=50, help='the depth of resnet')
    parser.add_argument('--num-classes', type=int, default=1000, help='the class number of your task')
    parser.add_argument('--aug-level', type=int, default=2, choices=[1, 2, 3],
                        help='level 1: use only random crop and random mirror\n'
                             'level 2: add scale/aspect/hsv augmentation based on level 1\n'
                             'level 3: add rotation/shear augmentation based on level 2')
    parser.add_argument('--num-examples', type=int, default=1281167, help='the number of training examples')
    parser.add_argument('--kv-store', type=str, default='device', help='the kvstore type')
    parser.add_argument('--model-load-epoch', type=int, default=0,
                        help='load the model on an epoch using the model-load-prefix')
    parser.add_argument('--frequent', type=int, default=50, help='frequency of logging')
    parser.add_argument('--memonger', action='store_true', default=False,
                        help='true means using memonger to save momory, https://github.com/dmlc/mxnet-memonger')
    parser.add_argument('--retrain', action='store_true', default=False, help='true means continue training')
    args = parser.parse_args()
    logging.info(args)
    main()

为了减小网络规模,将所有图片统一缩放到 128×128;网络通常使用 ResNet-50,并把 num_classes 改为实际需要的分类数目(本文代码中为 3 类)。

train acc可以在99.9%水平,val acc 稳定在80%左右

 

 

预测模块

"""Classify images with a trained ResNet checkpoint through the MXNet Module API."""
import argparse
import os
from collections import namedtuple

import cv2
import mxnet as mx
import numpy as np

# Minimal batch container passed to Module.forward(); it only carries ``data``.
Batch = namedtuple('Batch', ['data'])

# Directory the relative paths of the .lst file are resolved against by default.
_DEFAULT_IMAGE_ROOT = "/mnt/hdfs-data-4/data/jian.yin/ped_thumbnail/instances_test/"


def ch_dev(arg_params, aux_params, ctx):
    """Return copies of both parameter dicts with every array moved to ``ctx``."""
    new_args = {k: v.as_in_context(ctx) for k, v in arg_params.items()}
    new_auxs = {k: v.as_in_context(ctx) for k, v in aux_params.items()}
    return new_args, new_auxs


def predict(img, mod=None, labels=None, top_k=3):
    """Run one forward pass and print the ``top_k`` most probable classes.

    The original version read ``mod`` and ``labels`` from names that were never
    defined at module level, so it could not run; they are explicit parameters
    now.

    Args:
        img: Input batch as produced by ``_load_image``.
        mod: Module already bound for inference with its parameters set.
        labels: Sequence mapping class index to class name.
        top_k: Number of classes to print, most probable first.

    Raises:
        ValueError: If ``mod`` or ``labels`` is not supplied.
    """
    if mod is None or labels is None:
        raise ValueError("predict() requires a bound module and the label list")
    mod.forward(Batch([img]))
    prob = np.squeeze(mod.get_outputs()[0].asnumpy())
    ranked = np.argsort(prob)[::-1]
    for i in ranked[:top_k]:
        print('probability=%f, class=%s' % (prob[i], labels[i]))


def _load_image(path):
    """Decode ``path`` and return it as a float32 array of shape (1, 3, 128, 128)."""
    with open(path, 'rb') as f:
        img = mx.image.imdecode(f.read())
    img = mx.image.imresize(img, 128, 128)  # match the training resolution
    img = img.transpose((2, 0, 1))  # channel first
    img = img.expand_dims(axis=0)  # batchify
    return img.astype('float32')


def main():
    """Load the checkpoint named by the global ``args`` and classify images.

    With ``--lst`` set (the default), every line of that list file is read and
    its last tab-separated field is taken as an image path relative to
    ``--root``. With an empty ``--lst``, the single image ``--img`` is used.
    For each image the path is printed, followed by its top-3 predictions.
    """
    with open(args.synset) as f:
        synset = [l.strip() for l in f]

    ctx = mx.gpu(args.gpu)
    sym, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch)
    mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    mod.bind(for_training=False, data_shapes=[('data', (1, 3, 128, 128))],
             label_shapes=mod._label_shapes)
    mod.set_params(arg_params, aux_params, allow_missing=True)

    if args.lst:
        # Open the file the user asked for (it used to be hard-coded).
        with open(args.lst) as lst_file:
            for line in lst_file:
                # Strip the line ending first so a final line without a
                # newline does not lose a real character.
                rel_path = line.rstrip('\r\n').rsplit('\t', 1)[-1]
                if not rel_path:
                    continue  # blank line
                src = os.path.join(args.root, rel_path)
                print(src)
                predict(_load_image(src), mod, synset)
    else:
        print(args.img)
        predict(_load_image(args.img), mod, synset)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="use pre-trainned resnet model to classify one image")
    parser.add_argument('--img', type=str, default='test.jpg', help='input image for classification')
    # add --lst
    parser.add_argument('--lst', type=str, default='test.lst', help="input image's lst for classification")
    parser.add_argument('--root', type=str, default=_DEFAULT_IMAGE_ROOT,
                        help='directory that the image paths in --lst are relative to')
    parser.add_argument('--gpu', type=int, default=0, help='the gpu id used for predict')
    parser.add_argument('--synset', type=str, default='synset.txt', help='file mapping class id to class name')
    parser.add_argument('--prefix', type=str, default='resnet-50', help='the prefix of the pre-trained model')
    parser.add_argument('--epoch', type=int, default=0, help='the epoch of the pre-trained model')
    args = parser.parse_args()
    main()

添加了 --lst 可选参数,可以对 .lst 列表文件中列出的图片进行批量预测。

原文预测模块效率较低,改用mxnet标准的predict写法:https://mxnet.incubator.apache.org/tutorials/python/predict_image.html

添加一个脚本,防止忘记一些参数的写法:

#!/bin/bash
# Batch-classify every image listed in instances_test.lst.
# "python -u" keeps stdout unbuffered, so redirected output is written promptly.
python -u predict.py --lst instances_test.lst --prefix resnet-50 --synset ped_thumbnail.txt --gpu 0

记得运行的时候加上输出重定向 `>`,把预测结果保存到文件(例如 `> myPredict.txt`),供后面的分析脚本读取。输出格式如下:

/mnt/1/385_328_428_402_6.jpg
probability=0.994927, class=1 Cyclist
probability=0.003335, class=2 Others
probability=0.001739, class=0 Pedestrian
/mnt2/439_359_481_428_0.jpg
probability=0.994793, class=2 Others
probability=0.002817, class=0 Pedestrian
probability=0.002390, class=1 Cyclist
/mnt/2/619_337_658_401_16.jpg
probability=0.992218, class=2 Others
probability=0.007275, class=1 Cyclist
probability=0.000507, class=0 Pedestrian
/mnt1/511_288_561_385_1.jpg
probability=0.997837, class=1 Cyclist
probability=0.001525, class=0 Pedestrian
probability=0.000638, class=2 Others

分析预测结果

可以先把每一种(真实类别,预测类别)组合对应的图片路径分别记录到文件里,方便之后查看被分错的样本。

"""Summarise classifier output: confusion matrix plus one path list per (true, pred) pair."""
import itertools

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, svm
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

PREDICTION_FILE = 'myPredict.txt'
# Each record is one image-path line followed by three "probability=..." lines.
LINES_PER_RECORD = 4
# Class index -> word used in the output file names (zero_one.txt, ...).
CLASS_WORDS = ('zero', 'one', 'two')


def _digit_before_last(line, sep):
    """Return the single digit that directly precedes the last ``sep`` in ``line``.

    Labels are assumed to be one digit wide, as in the original script.

    Raises:
        ValueError: If ``sep`` does not occur in ``line`` or the character
            before it is not a digit.
    """
    idx = line.rfind(sep)
    if idx < 0:
        raise ValueError("no %r found in line %r" % (sep, line))
    return int(line[idx - 1])


def load_records(path):
    """Parse the prediction log at ``path`` in a single pass.

    Returns:
        ``(paths, true, pred)``: the raw image-path lines (newline kept), the
        true label taken from the last character of each image's parent
        directory, and the predicted label taken from the first (most
        probable) probability line.

    Raises:
        ValueError: If the file ends in the middle of a record.
    """
    paths, true, pred = [], [], []
    with open(path) as f:
        for idx, line in enumerate(f):
            slot = idx % LINES_PER_RECORD
            if slot == 0:
                paths.append(line)
                true.append(_digit_before_last(line, '/'))
            elif slot == 1:
                # "probability=..., class=<digit> <name>": digit before the last space.
                pred.append(_digit_before_last(line, ' '))
    if len(true) != len(pred):
        raise ValueError("%s is truncated: %d image lines but %d prediction lines"
                         % (path, len(true), len(pred)))
    return paths, true, pred


def main():
    """Print labels, confusion matrix and error count; write the nine path files."""
    paths, true, pred = load_records(PREDICTION_FILE)
    print(true)
    print(pred)
    print(confusion_matrix(true, pred))

    n_classes = len(CLASS_WORDS)
    buckets = {(t, p): [] for t in range(n_classes) for p in range(n_classes)}
    for line, t, p in zip(paths, true, pred):
        if (t, p) in buckets:  # labels outside 0..2 are ignored, as before
            buckets[(t, p)].append(line)

    # Total number of misclassified images (off-diagonal cells).
    print(sum(len(lines) for (t, p), lines in buckets.items() if t != p))

    # Every file is created even when its bucket is empty, e.g. zero_one.txt
    # holds images whose true class is 0 but which were predicted as 1.
    for (t, p), lines in buckets.items():
        with open('%s_%s.txt' % (CLASS_WORDS[t], CLASS_WORDS[p]), 'w') as out:
            out.writelines(lines)


if __name__ == "__main__":
    main()

混淆矩阵如下:

 

转载于:https://www.cnblogs.com/TreeDream/p/10102626.html

你可能感兴趣的文章
解析python数据后用html输出
查看>>
mysql日志分析工具之mysqlsla
查看>>
float 属性详解
查看>>
javascript操作select下拉列表框的一点小经验
查看>>
LeeCode-Majority Element
查看>>
PowerPoint动画制作时的需要注意的N个事项
查看>>
我的友情链接
查看>>
Racktables(一)的资产管理软件安装配置
查看>>
crontab命令
查看>>
阿里云磁盘扩容踩坑总结
查看>>
Java中的String字符串
查看>>
EDM营销中HTML邮件设计方法和技巧
查看>>
shell的最大命令行长度
查看>>
刚搭建的linux环境的基本优化以及优化脚本---菜鸟初写
查看>>
北京20-35万Sharepoint Senior Consultant
查看>>
Vue开源项目库汇总
查看>>
apache 站点安全
查看>>
zabbix监控管理维护脚本
查看>>
使用async/await——Nodejs+ExpressJs+Babel
查看>>
贵州省新农合业务系统容灾技术支撑服务项目
查看>>