diff --git a/ai.py b/ai.py index f2821ad..5bcd6f2 100644 --- a/ai.py +++ b/ai.py @@ -12,6 +12,7 @@ import mediapipe as mp import torch import torch.nn as nn import numpy as np +import tkinter as tk import shutil import math from scipy import stats @@ -260,17 +261,16 @@ class AI: optimizer.zero_grad() # clear gradients for this training step loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients - if (step + 1) % 100 == 0: # 输出结果 - if (step + 1) % 100 == 0: # 输出结果 - print( - "\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]" - % ( - epoch, - step * len(data), - len(self.train_loader.dataset), - 100. * step / len(self.train_loader), - loss.item() - ), end="") + if (step + 1) % 50 == 0: # 输出结果 + print( + "\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]" + % ( + epoch + 1, + (step + 1) * len(data), + len(self.train_loader.dataset), + 100. * (step + 1) / len(self.train_loader), + loss.item() + ), end="") cnn.out_label = self.out_label torch.save(cnn, 'CNN.pkl') @@ -284,6 +284,20 @@ class Main: self.default_datasets = "Datasets" self.len_x = 22 self.len_y = 4 + self.label = '' + self.top1 = tk.Tk() + self.top1.geometry('300x50') + self.top1.title('请输入标签') + tk.Label(self.top1, text='Label:').place(x=27, y=10) + self.entry = tk.Entry(self.top1, width=15) + self.entry.place(x=80, y=10) + tk.Button(self.top1, text='确定', command=self.change_state).place(x=235, y=5) + + def change_state(self): + self.label = self.entry.get() # 调用get()方法,将Entry中的内容获取出来 + self.top1.quit() + if self.label == "": + self.top1.destroy() def make_datasets(self, datasets_dir="default", n=100): if datasets_dir == "default": @@ -295,8 +309,8 @@ class Main: self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW) self.camera.set(3, 1280) self.camera.set(4, 720) - label = input("label:") - while not label == "": + self.top1.mainloop() + while not self.label == "": data = np.zeros([n, self.len_x, self.len_y]) shape_list = np.zeros([n, 2], dtype=np.int16) hand_type = np.zeros(n, dtype=np.int8) @@ -332,10 +346,11 @@ class Main: break cv2.destroyAllWindows() - open(datasets_dir + "/" + label + ".npz", "w") - np.savez(datasets_dir + "/" + label + ".npz", label=label, data=data, + open(datasets_dir + "/" + self.label + ".npz", "w") + np.savez(datasets_dir + "/" + self.label + ".npz", label=self.label, data=data, handtype=hand_type, shape=shape_list) - label = input("label:") + + self.top1.mainloop() def train(self, datasets_dir="default"): if datasets_dir == "default":