修改:增加众数判定

This commit is contained in:
leaf 2022-06-29 09:39:50 +08:00
parent 8bfe21e11b
commit 96c6039474
2 changed files with 42 additions and 16 deletions

BIN
CNN.pkl

Binary file not shown.

58
ai.py
View File

@ -14,6 +14,7 @@ import torch.nn as nn
import numpy as np import numpy as np
import shutil import shutil
import math import math
from scipy import stats
from os.path import exists from os.path import exists
from os import mkdir from os import mkdir
from pathlib import Path from pathlib import Path
@ -27,6 +28,13 @@ def rotate(angle, x, y, point_x, point_y):
return px, py return px, py
# 归一化
def normalize(x):
max_x = np.max(x)
min_x = np.min(x)
return (x-min_x)/(max_x-min_x)
class CNN(nn.Module): class CNN(nn.Module):
def __init__(self, m): def __init__(self, m):
super(CNN, self).__init__() super(CNN, self).__init__()
@ -45,9 +53,9 @@ class CNN(nn.Module):
self.conv2 = nn.Sequential( self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2), nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(), nn.ReLU(),
nn.MaxPool2d(3), nn.MaxPool2d(2),
) )
self.med = nn.Linear(32 * 7 * 1, 500) self.med = nn.Linear(32 * 11 * 2, 500)
self.med2 = nn.Linear(1*21*3, 100) self.med2 = nn.Linear(1*21*3, 100)
self.med3 = nn.Linear(100, 500) self.med3 = nn.Linear(100, 500)
self.out = nn.Linear(500, m) # fully connected layer, output 10 classes self.out = nn.Linear(500, m) # fully connected layer, output 10 classes
@ -131,7 +139,7 @@ class HandDetector:
px, py = int(lm.x * w), int(lm.y * h) px, py = int(lm.x * w), int(lm.y * h)
x_list.append(px) x_list.append(px)
y_list.append(py) y_list.append(py)
self.lmList.append([lm.x, lm.y, lm.z]) self.lmList.append([lm.x, lm.y, 0])
if draw: if draw:
cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED) cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
x_min, x_max = min(x_list), max(x_list) x_min, x_max = min(x_list), max(x_list)
@ -147,6 +155,10 @@ class HandDetector:
(0, 255, 0), 2) (0, 255, 0), 2)
self.revolve(img) self.revolve(img)
self.re_lmList = np.array(self.re_lmList)
if self.re_lmList.any():
self.re_lmList = np.concatenate((np.zeros((21, 1)), self.re_lmList), axis=1)
self.re_lmList = np.concatenate((self.re_lmList, np.zeros((1, 4))), axis=0)
return self.re_lmList, bbox_info return self.re_lmList, bbox_info
@ -176,11 +188,18 @@ class HandDetector:
theta = theta + math.pi theta = theta + math.pi
# print(theta*180/math.pi) # print(theta*180/math.pi)
for i in self.lmList: for i in self.lmList:
x, y = rotate(theta, i[0], i[1], point_x, point_y) px, py = rotate(theta, i[0] * w, i[1] * h, point_x * w, point_y * h)
px, py = int(x * w), int(y * h) self.re_lmList.append([px, py, 0])
self.re_lmList.append([x, y, i[2]])
if draw: if draw:
cv2.circle(img, (px, py), 5, (0, 0, 255), cv2.FILLED) cv2.circle(img, (int(px), int(py)), 5, (0, 0, 255), cv2.FILLED)
# 归一化
x_array = normalize(np.array(self.re_lmList)[:, 0])
# print(x_array)
for i in range(len(x_array)):
self.re_lmList[i][0] = x_array[i]
y_array = normalize(np.array(self.re_lmList)[:, 1])
for i in range(len(y_array)):
self.re_lmList[i][1] = x_array[i]
else: else:
self.re_lmList = self.lmList self.re_lmList = self.lmList
return self.re_lmList return self.re_lmList
@ -263,6 +282,8 @@ class Main:
self.camera = None self.camera = None
self.detector = HandDetector() self.detector = HandDetector()
self.default_datasets = "Datasets" self.default_datasets = "Datasets"
self.len_x = 22
self.len_y = 4
def make_datasets(self, datasets_dir="default", n=100): def make_datasets(self, datasets_dir="default", n=100):
if datasets_dir == "default": if datasets_dir == "default":
@ -276,24 +297,23 @@ class Main:
self.camera.set(4, 720) self.camera.set(4, 720)
label = input("label:") label = input("label:")
while not label == "": while not label == "":
data = np.zeros([n, 21, 3]) data = np.zeros([n, self.len_x, self.len_y])
shape_list = np.zeros([n, 2], dtype=np.int16) shape_list = np.zeros([n, 2], dtype=np.int16)
hand_type = np.zeros(n, dtype=np.int8) hand_type = np.zeros(n, dtype=np.int8)
zero_data = np.zeros([21, 3])
count = 0 count = 0
cv2.startWindowThread() cv2.startWindowThread()
while True: while True:
frame, img = self.camera.read() frame, img = self.camera.read()
img = self.detector.find_hands(img) img = self.detector.find_hands(img)
result = np.zeros((21, 3)) result = np.zeros((self.len_x, self.len_y))
lm_list, bbox = self.detector.find_position(img) lm_list, bbox = self.detector.find_position(img)
for i in range(len(lm_list)): for i in range(len(lm_list)):
result[i] = np.array(lm_list[i]) result[i] = np.array(lm_list[i])
shape = bbox["shape"] if result.sum() > 0: # 假设矩阵不为0即捕捉到手部时
if result.all() != zero_data.all(): # 假设矩阵不为0即捕捉到手部时 shape = bbox["shape"]
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1] x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
data[count] = result data[count] = result
hand_type[count] = self.detector.hand_type() hand_type[count] = self.detector.hand_type()
@ -332,20 +352,26 @@ class Main:
self.detector = HandDetector() self.detector = HandDetector()
cnn = torch.load("CNN.pkl") cnn = torch.load("CNN.pkl")
out_label = cnn.out_label out_label = cnn.out_label
result = []
disp = ""
while True: while True:
frame, img = self.camera.read() frame, img = self.camera.read()
img = self.detector.find_hands(img) img = self.detector.find_hands(img)
lm_list, bbox = self.detector.find_position(img) lm_list, bbox = self.detector.find_position(img)
if lm_list: if lm_list.any():
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1] x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
data = torch.Tensor(lm_list) data = torch.Tensor(lm_list)
data = data.unsqueeze(0) data = data.unsqueeze(0)
data = data.unsqueeze(0) data = data.unsqueeze(0)
test_output = cnn(data) test_output = cnn(data)
result = torch.max(test_output, 1)[1].data.cpu().numpy()[0] result.append(torch.max(test_output, 1)[1].data.cpu().numpy()[0])
cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3, if len(result) > 5:
disp = str(out_label[stats.mode(result)[0][0]])
result = []
cv2.putText(img, disp, (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
(0, 0, 255), 3) (0, 0, 255), 3)
cv2.imshow("camera", img) cv2.imshow("camera", img)
@ -359,6 +385,6 @@ class Main:
if __name__ == '__main__': if __name__ == '__main__':
solution = Main() solution = Main()
my_datasets_dir = "test" my_datasets_dir = "test"
solution.make_datasets(my_datasets_dir, 50) solution.make_datasets(my_datasets_dir, 100)
solution.train(my_datasets_dir) solution.train(my_datasets_dir)
solution.gesture_recognition() solution.gesture_recognition()