car、cat、flower,每个种别200张,图像:(64,64,3)

57.人工智能——图像分类:自定义演习、评估、猜测过程_类别_模子 文字写作

car

cat

flower

二、网络模型准备

数据演习前,须要准备网络模型,这里直策应用paddle中内置的模型resnet18

#选择内置模型model=paddle.vision.models.resnet18(num_classes=num_classes)三、自定义演习和评估过程

#选择内置模型model=paddle.vision.models.resnet18(num_classes=num_classes)#设置模型参数opt=paddle.optimizer.Adam(learning_rate=0.001,parameters=model.parameters())epochs=20bestacc=0 model.train() #开启演习模式train_loader=paddle.io.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)val_loader=paddle.io.DataLoader(val_dataset,batch_size=batch_size,shuffle=False)for epoch in range(epochs): for i,(data,label) in enumerate(train_loader()): pred=model(data) loss=paddle.nn.functional.cross_entropy(pred,label) avg_loss=paddle.mean(loss) if i % 10==0: print(f"epoch:{epoch},batch:{i},loss:{avg_loss.numpy()}") avg_loss.backward() opt.step() opt.clear_grad() model.eval() accs=[] losses=[] for i ,(data,label) in enumerate(val_loader()): label=paddle.reshape(label,(-1,1)) pred=model(data) loss=paddle.nn.functional.cross_entropy(pred,label) acc=paddle.metric.accuracy(pred,label) accs.append(acc.numpy()) losses.append(loss.numpy()) print(f"epoch:{epoch},loss:{np.mean(losses)},acc:{np.mean(accs)}") if np.mean(accs)>bestacc: bestacc=np.mean(accs) print(f"save best model,acc:{bestacc},epoch:{epoch}") paddle.save(model.state_dict(),"models/best.pdparams") model.train() #返回演习模式

#演习与评估过程数据epoch:0,batch:0,loss:[1.3691025]epoch:0,batch:10,loss:[0.46607488]epoch:0,loss:5.3897786140441895,acc:0.4791666865348816save best model,acc:0.4791666865348816,epoch:0epoch:1,batch:0,loss:[0.8707124]epoch:1,batch:10,loss:[0.96523666]epoch:1,loss:0.6027684211730957,acc:0.8229166865348816save best model,acc:0.8229166865348816,epoch:1epoch:2,batch:0,loss:[0.42904758]epoch:2,batch:10,loss:[0.32614255]epoch:2,loss:0.5535903573036194,acc:0.8489583134651184save best model,acc:0.8489583134651184,epoch:2epoch:3,batch:0,loss:[0.3938939]epoch:3,batch:10,loss:[0.3688711]epoch:3,loss:0.3994542360305786,acc:0.8880208134651184save best model,acc:0.8880208134651184,epoch:3epoch:4,batch:0,loss:[0.22923152]epoch:4,batch:10,loss:[0.29723316]epoch:4,loss:0.4439505338668823,acc:0.8098958134651184epoch:5,batch:0,loss:[0.06217474]epoch:5,batch:10,loss:[0.32871988]epoch:5,loss:0.5640340447425842,acc:0.7942708134651184epoch:6,batch:0,loss:[0.42060828]epoch:6,batch:10,loss:[0.18327941]epoch:6,loss:0.5131557583808899,acc:0.8385416865348816epoch:7,batch:0,loss:[0.3300568]epoch:7,batch:10,loss:[0.36528504]epoch:7,loss:0.44314220547676086,acc:0.8515625epoch:8,batch:0,loss:[0.19298445]epoch:8,batch:10,loss:[0.10667433]epoch:8,loss:0.794183075428009,acc:0.7734375epoch:9,batch:0,loss:[0.16208875]epoch:9,batch:10,loss:[0.21965706]epoch:9,loss:0.4500595033168793,acc:0.84375epoch:10,batch:0,loss:[0.1167096]epoch:10,batch:10,loss:[0.18035632]epoch:10,loss:0.4533429443836212,acc:0.8489583134651184epoch:11,batch:0,loss:[0.05396715]epoch:11,batch:10,loss:[0.5844509]epoch:11,loss:0.4798758625984192,acc:0.8541666865348816epoch:12,batch:0,loss:[0.38118595]epoch:12,batch:10,loss:[0.23948127]epoch:12,loss:0.3134278655052185,acc:0.9348958134651184save best model,acc:0.9348958134651184,epoch:12epoch:13,batch:0,loss:[0.57383955]epoch:13,batch:10,loss:[0.30720592]epoch:13,loss:0.5933756828308105,acc:0.8255208134651184epoch:14,batch:0,loss:[0.25598395]epoch:14,batch:10,loss:[0.18520543]epoch:14,loss:0.3323279917240143,acc:0.8697916865348816epoch:15,batch:0,loss:[0.05330365]epoch:15,batch:10,loss:[0.22565426]epoch:15,loss:0.4196951687335968,acc:0.8671875epoch:16,batch:0,loss:[0.19039027]epoch:16,batch:10,loss:[0.36966798]epoch:16,loss:0.5373609662055969,acc:0.859375epoch:17,batch:0,loss:[0.06061709]epoch:17,batch:10,loss:[0.09108147]epoch:17,loss:0.3177206218242645,acc:0.921875epoch:18,batch:0,loss:[0.1017945]epoch:18,batch:10,loss:[0.09792318]epoch:18,loss:0.4859660267829895,acc:0.859375epoch:19,batch:0,loss:[0.03445277]epoch:19,batch:10,loss:[0.05108456]epoch:19,loss:0.4685693681240082,acc:0.8828125

从上面数据看,最好的模型在第12轮,准确率达94.5%。
这里要把稳一下,过早得到最佳模型参数并不好。

四、模型预测

在上一文中,也有模型预测,这里加了一个得到预测最大值的索引函数:argmax,

#加载模型model=networkmodel_dict=paddle.load("models/best.pdparams")model.load_dict(model_dict)model.eval()idx=np.random.randint(0,len(test_set_x))print("测试数据id:",idx)img=test_dataset[idx][0]real_label=test_dataset[idx][1]print(img.shape)img=img.reshape((1,img.shape[0],img.shape[1],img.shape[2]))results=model(paddle.to_tensor(img))#print(results)predictlabel1=np.argsort(results.numpy())[0][-1] #最大值的索引,用argsortpredictlabel=np.argmax(results.numpy()) #最大值的索引,用argmax#print(predictlabel1,predictlabel)print("预测结果:",classes[predictlabel].decode("utf-8"),"实际种别:",classes[real_label].decode("utf-8"))

随机几条测试数据,运行结果:

测试数据id: 6 [3, 64, 64] 预测结果: flower 实际种别: flower

测试数据id: 20 [3, 64, 64] 预测结果: car 实际种别: car

测试数据id: 56 [3, 64, 64] 预测结果: car 实际种别: car

测试数据id: 58 [3, 64, 64] 预测结果: cat 实际种别: cat

测试数据id: 19 [3, 64, 64] 预测结果: flower 实际种别: flower