增加UI
优化(zj)
This commit is contained in:
parent
96c6039474
commit
717f9fe606
47
ai.py
47
ai.py
@ -12,6 +12,7 @@ import mediapipe as mp
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import tkinter as tk
|
||||||
import shutil
|
import shutil
|
||||||
import math
|
import math
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
@ -260,17 +261,16 @@ class AI:
|
|||||||
optimizer.zero_grad() # clear gradients for this training step
|
optimizer.zero_grad() # clear gradients for this training step
|
||||||
loss.backward() # backpropagation, compute gradients
|
loss.backward() # backpropagation, compute gradients
|
||||||
optimizer.step() # apply gradients
|
optimizer.step() # apply gradients
|
||||||
if (step + 1) % 100 == 0: # 输出结果
|
if (step + 1) % 50 == 0: # 输出结果
|
||||||
if (step + 1) % 100 == 0: # 输出结果
|
print(
|
||||||
print(
|
"\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]"
|
||||||
"\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]"
|
% (
|
||||||
% (
|
epoch + 1,
|
||||||
epoch,
|
(step + 1) * len(data),
|
||||||
step * len(data),
|
len(self.train_loader.dataset),
|
||||||
len(self.train_loader.dataset),
|
100. * (step + 1) / len(self.train_loader),
|
||||||
100. * step / len(self.train_loader),
|
loss.item()
|
||||||
loss.item()
|
), end="")
|
||||||
), end="")
|
|
||||||
|
|
||||||
cnn.out_label = self.out_label
|
cnn.out_label = self.out_label
|
||||||
torch.save(cnn, 'CNN.pkl')
|
torch.save(cnn, 'CNN.pkl')
|
||||||
@ -284,6 +284,20 @@ class Main:
|
|||||||
self.default_datasets = "Datasets"
|
self.default_datasets = "Datasets"
|
||||||
self.len_x = 22
|
self.len_x = 22
|
||||||
self.len_y = 4
|
self.len_y = 4
|
||||||
|
self.label = ''
|
||||||
|
self.top1 = tk.Tk()
|
||||||
|
self.top1.geometry('300x50')
|
||||||
|
self.top1.title('请输入标签')
|
||||||
|
tk.Label(self.top1, text='Label:').place(x=27, y=10)
|
||||||
|
self.entry = tk.Entry(self.top1, width=15)
|
||||||
|
self.entry.place(x=80, y=10)
|
||||||
|
tk.Button(self.top1, text='确定', command=self.change_state).place(x=235, y=5)
|
||||||
|
|
||||||
|
def change_state(self):
|
||||||
|
self.label = self.entry.get() # 调用get()方法,将Entry中的内容获取出来
|
||||||
|
self.top1.quit()
|
||||||
|
if self.label == "":
|
||||||
|
self.top1.destroy()
|
||||||
|
|
||||||
def make_datasets(self, datasets_dir="default", n=100):
|
def make_datasets(self, datasets_dir="default", n=100):
|
||||||
if datasets_dir == "default":
|
if datasets_dir == "default":
|
||||||
@ -295,8 +309,8 @@ class Main:
|
|||||||
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
||||||
self.camera.set(3, 1280)
|
self.camera.set(3, 1280)
|
||||||
self.camera.set(4, 720)
|
self.camera.set(4, 720)
|
||||||
label = input("label:")
|
self.top1.mainloop()
|
||||||
while not label == "":
|
while not self.label == "":
|
||||||
data = np.zeros([n, self.len_x, self.len_y])
|
data = np.zeros([n, self.len_x, self.len_y])
|
||||||
shape_list = np.zeros([n, 2], dtype=np.int16)
|
shape_list = np.zeros([n, 2], dtype=np.int16)
|
||||||
hand_type = np.zeros(n, dtype=np.int8)
|
hand_type = np.zeros(n, dtype=np.int8)
|
||||||
@ -332,10 +346,11 @@ class Main:
|
|||||||
break
|
break
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
open(datasets_dir + "/" + label + ".npz", "w")
|
open(datasets_dir + "/" + self.label + ".npz", "w")
|
||||||
np.savez(datasets_dir + "/" + label + ".npz", label=label, data=data,
|
np.savez(datasets_dir + "/" + self.label + ".npz", label=self.label, data=data,
|
||||||
handtype=hand_type, shape=shape_list)
|
handtype=hand_type, shape=shape_list)
|
||||||
label = input("label:")
|
|
||||||
|
self.top1.mainloop()
|
||||||
|
|
||||||
def train(self, datasets_dir="default"):
|
def train(self, datasets_dir="default"):
|
||||||
if datasets_dir == "default":
|
if datasets_dir == "default":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user