优化(zj)
This commit is contained in:
leaf 2022-06-29 19:35:16 +08:00
parent 96c6039474
commit 717f9fe606

47
ai.py
View File

@ -12,6 +12,7 @@ import mediapipe as mp
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import tkinter as tk
import shutil import shutil
import math import math
from scipy import stats from scipy import stats
@ -260,17 +261,16 @@ class AI:
optimizer.zero_grad() # clear gradients for this training step optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients optimizer.step() # apply gradients
if (step + 1) % 100 == 0: # 输出结果 if (step + 1) % 50 == 0: # 输出结果
if (step + 1) % 100 == 0: # 输出结果 print(
print( "\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]"
"\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]" % (
% ( epoch + 1,
epoch, (step + 1) * len(data),
step * len(data), len(self.train_loader.dataset),
len(self.train_loader.dataset), 100. * (step + 1) / len(self.train_loader),
100. * step / len(self.train_loader), loss.item()
loss.item() ), end="")
), end="")
cnn.out_label = self.out_label cnn.out_label = self.out_label
torch.save(cnn, 'CNN.pkl') torch.save(cnn, 'CNN.pkl')
@ -284,6 +284,20 @@ class Main:
self.default_datasets = "Datasets" self.default_datasets = "Datasets"
self.len_x = 22 self.len_x = 22
self.len_y = 4 self.len_y = 4
self.label = ''
self.top1 = tk.Tk()
self.top1.geometry('300x50')
self.top1.title('请输入标签')
tk.Label(self.top1, text='Label:').place(x=27, y=10)
self.entry = tk.Entry(self.top1, width=15)
self.entry.place(x=80, y=10)
tk.Button(self.top1, text='确定', command=self.change_state).place(x=235, y=5)
def change_state(self):
self.label = self.entry.get() # 调用get()方法将Entry中的内容获取出来
self.top1.quit()
if self.label == "":
self.top1.destroy()
def make_datasets(self, datasets_dir="default", n=100): def make_datasets(self, datasets_dir="default", n=100):
if datasets_dir == "default": if datasets_dir == "default":
@ -295,8 +309,8 @@ class Main:
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW) self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
self.camera.set(3, 1280) self.camera.set(3, 1280)
self.camera.set(4, 720) self.camera.set(4, 720)
label = input("label:") self.top1.mainloop()
while not label == "": while not self.label == "":
data = np.zeros([n, self.len_x, self.len_y]) data = np.zeros([n, self.len_x, self.len_y])
shape_list = np.zeros([n, 2], dtype=np.int16) shape_list = np.zeros([n, 2], dtype=np.int16)
hand_type = np.zeros(n, dtype=np.int8) hand_type = np.zeros(n, dtype=np.int8)
@ -332,10 +346,11 @@ class Main:
break break
cv2.destroyAllWindows() cv2.destroyAllWindows()
open(datasets_dir + "/" + label + ".npz", "w") open(datasets_dir + "/" + self.label + ".npz", "w")
np.savez(datasets_dir + "/" + label + ".npz", label=label, data=data, np.savez(datasets_dir + "/" + self.label + ".npz", label=self.label, data=data,
handtype=hand_type, shape=shape_list) handtype=hand_type, shape=shape_list)
label = input("label:")
self.top1.mainloop()
def train(self, datasets_dir="default"): def train(self, datasets_dir="default"):
if datasets_dir == "default": if datasets_dir == "default":