优化(zj)
This commit is contained in:
leaf 2022-06-29 19:35:16 +08:00
parent 96c6039474
commit 717f9fe606

35
ai.py
View File

@ -12,6 +12,7 @@ import mediapipe as mp
import torch
import torch.nn as nn
import numpy as np
import tkinter as tk
import shutil
import math
from scipy import stats
@ -260,15 +261,14 @@ class AI:
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
if (step + 1) % 100 == 0: # 输出结果
if (step + 1) % 100 == 0: # 输出结果
if (step + 1) % 50 == 0: # 输出结果
print(
"\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]"
% (
epoch,
step * len(data),
epoch + 1,
(step + 1) * len(data),
len(self.train_loader.dataset),
100. * step / len(self.train_loader),
100. * (step + 1) / len(self.train_loader),
loss.item()
), end="")
@ -284,6 +284,20 @@ class Main:
self.default_datasets = "Datasets"
self.len_x = 22
self.len_y = 4
self.label = ''
self.top1 = tk.Tk()
self.top1.geometry('300x50')
self.top1.title('请输入标签')
tk.Label(self.top1, text='Label:').place(x=27, y=10)
self.entry = tk.Entry(self.top1, width=15)
self.entry.place(x=80, y=10)
tk.Button(self.top1, text='确定', command=self.change_state).place(x=235, y=5)
def change_state(self):
self.label = self.entry.get() # 调用get()方法将Entry中的内容获取出来
self.top1.quit()
if self.label == "":
self.top1.destroy()
def make_datasets(self, datasets_dir="default", n=100):
if datasets_dir == "default":
@ -295,8 +309,8 @@ class Main:
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
self.camera.set(3, 1280)
self.camera.set(4, 720)
label = input("label:")
while not label == "":
self.top1.mainloop()
while not self.label == "":
data = np.zeros([n, self.len_x, self.len_y])
shape_list = np.zeros([n, 2], dtype=np.int16)
hand_type = np.zeros(n, dtype=np.int8)
@ -332,10 +346,11 @@ class Main:
break
cv2.destroyAllWindows()
open(datasets_dir + "/" + label + ".npz", "w")
np.savez(datasets_dir + "/" + label + ".npz", label=label, data=data,
open(datasets_dir + "/" + self.label + ".npz", "w")
np.savez(datasets_dir + "/" + self.label + ".npz", label=self.label, data=data,
handtype=hand_type, shape=shape_list)
label = input("label:")
self.top1.mainloop()
def train(self, datasets_dir="default"):
if datasets_dir == "default":