增加UI
优化(zj)
This commit is contained in:
parent
96c6039474
commit
717f9fe606
35
ai.py
35
ai.py
@ -12,6 +12,7 @@ import mediapipe as mp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import tkinter as tk
|
||||
import shutil
|
||||
import math
|
||||
from scipy import stats
|
||||
@ -260,15 +261,14 @@ class AI:
|
||||
optimizer.zero_grad() # clear gradients for this training step
|
||||
loss.backward() # backpropagation, compute gradients
|
||||
optimizer.step() # apply gradients
|
||||
if (step + 1) % 100 == 0: # 输出结果
|
||||
if (step + 1) % 100 == 0: # 输出结果
|
||||
if (step + 1) % 50 == 0: # 输出结果
|
||||
print(
|
||||
"\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]"
|
||||
% (
|
||||
epoch,
|
||||
step * len(data),
|
||||
epoch + 1,
|
||||
(step + 1) * len(data),
|
||||
len(self.train_loader.dataset),
|
||||
100. * step / len(self.train_loader),
|
||||
100. * (step + 1) / len(self.train_loader),
|
||||
loss.item()
|
||||
), end="")
|
||||
|
||||
@ -284,6 +284,20 @@ class Main:
|
||||
self.default_datasets = "Datasets"
|
||||
self.len_x = 22
|
||||
self.len_y = 4
|
||||
self.label = ''
|
||||
self.top1 = tk.Tk()
|
||||
self.top1.geometry('300x50')
|
||||
self.top1.title('请输入标签')
|
||||
tk.Label(self.top1, text='Label:').place(x=27, y=10)
|
||||
self.entry = tk.Entry(self.top1, width=15)
|
||||
self.entry.place(x=80, y=10)
|
||||
tk.Button(self.top1, text='确定', command=self.change_state).place(x=235, y=5)
|
||||
|
||||
def change_state(self):
|
||||
self.label = self.entry.get() # 调用get()方法,将Entry中的内容获取出来
|
||||
self.top1.quit()
|
||||
if self.label == "":
|
||||
self.top1.destroy()
|
||||
|
||||
def make_datasets(self, datasets_dir="default", n=100):
|
||||
if datasets_dir == "default":
|
||||
@ -295,8 +309,8 @@ class Main:
|
||||
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
||||
self.camera.set(3, 1280)
|
||||
self.camera.set(4, 720)
|
||||
label = input("label:")
|
||||
while not label == "":
|
||||
self.top1.mainloop()
|
||||
while not self.label == "":
|
||||
data = np.zeros([n, self.len_x, self.len_y])
|
||||
shape_list = np.zeros([n, 2], dtype=np.int16)
|
||||
hand_type = np.zeros(n, dtype=np.int8)
|
||||
@ -332,10 +346,11 @@ class Main:
|
||||
break
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
open(datasets_dir + "/" + label + ".npz", "w")
|
||||
np.savez(datasets_dir + "/" + label + ".npz", label=label, data=data,
|
||||
open(datasets_dir + "/" + self.label + ".npz", "w")
|
||||
np.savez(datasets_dir + "/" + self.label + ".npz", label=self.label, data=data,
|
||||
handtype=hand_type, shape=shape_list)
|
||||
label = input("label:")
|
||||
|
||||
self.top1.mainloop()
|
||||
|
||||
def train(self, datasets_dir="default"):
|
||||
if datasets_dir == "default":
|
||||
|
Loading…
x
Reference in New Issue
Block a user