Browse Source

修改:增加众数判定

AI
leaf 2 years ago
parent
commit
96c6039474
  1. BIN
      CNN.pkl
  2. 58
      ai.py

BIN
CNN.pkl

Binary file not shown.

58
ai.py

@ -14,6 +14,7 @@ import torch.nn as nn
import numpy as np
import shutil
import math
from scipy import stats
from os.path import exists
from os import mkdir
from pathlib import Path
@ -27,6 +28,13 @@ def rotate(angle, x, y, point_x, point_y):
return px, py
# Min-max normalization helper used on landmark coordinate columns.
def normalize(x):
    """Rescale the values of *x* to the [0, 1] range (min-max scaling).

    Parameters
    ----------
    x : numpy.ndarray (or array_like accepted by np.max/np.min)
        Numeric values to rescale. Callers in this file pass one column
        of the landmark array, e.g. ``np.array(self.re_lmList)[:, 0]``.

    Returns
    -------
    numpy.ndarray
        ``(x - min) / (max - min)``. When every element of *x* is equal
        the original expression would divide by zero and produce
        NaN/inf; in that degenerate case an all-zero array is returned
        instead.
    """
    max_x = np.max(x)
    min_x = np.min(x)
    span = max_x - min_x
    if span == 0:
        # All values identical: avoid 0/0 -> NaN; zeros are a sane
        # "no spread" normalization result.
        return np.zeros_like(np.asarray(x, dtype=float))
    return (x - min_x) / span
class CNN(nn.Module):
def __init__(self, m):
super(CNN, self).__init__()
@ -45,9 +53,9 @@ class CNN(nn.Module):
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(3),
nn.MaxPool2d(2),
)
self.med = nn.Linear(32 * 7 * 1, 500)
self.med = nn.Linear(32 * 11 * 2, 500)
self.med2 = nn.Linear(1*21*3, 100)
self.med3 = nn.Linear(100, 500)
self.out = nn.Linear(500, m) # fully connected layer, output 10 classes
@ -131,7 +139,7 @@ class HandDetector:
px, py = int(lm.x * w), int(lm.y * h)
x_list.append(px)
y_list.append(py)
self.lmList.append([lm.x, lm.y, lm.z])
self.lmList.append([lm.x, lm.y, 0])
if draw:
cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
x_min, x_max = min(x_list), max(x_list)
@ -147,6 +155,10 @@ class HandDetector:
(0, 255, 0), 2)
self.revolve(img)
self.re_lmList = np.array(self.re_lmList)
if self.re_lmList.any():
self.re_lmList = np.concatenate((np.zeros((21, 1)), self.re_lmList), axis=1)
self.re_lmList = np.concatenate((self.re_lmList, np.zeros((1, 4))), axis=0)
return self.re_lmList, bbox_info
@ -176,11 +188,18 @@ class HandDetector:
theta = theta + math.pi
# print(theta*180/math.pi)
for i in self.lmList:
x, y = rotate(theta, i[0], i[1], point_x, point_y)
px, py = int(x * w), int(y * h)
self.re_lmList.append([x, y, i[2]])
px, py = rotate(theta, i[0] * w, i[1] * h, point_x * w, point_y * h)
self.re_lmList.append([px, py, 0])
if draw:
cv2.circle(img, (px, py), 5, (0, 0, 255), cv2.FILLED)
cv2.circle(img, (int(px), int(py)), 5, (0, 0, 255), cv2.FILLED)
# 归一化
x_array = normalize(np.array(self.re_lmList)[:, 0])
# print(x_array)
for i in range(len(x_array)):
self.re_lmList[i][0] = x_array[i]
y_array = normalize(np.array(self.re_lmList)[:, 1])
for i in range(len(y_array)):
self.re_lmList[i][1] = x_array[i]
else:
self.re_lmList = self.lmList
return self.re_lmList
@ -263,6 +282,8 @@ class Main:
self.camera = None
self.detector = HandDetector()
self.default_datasets = "Datasets"
self.len_x = 22
self.len_y = 4
def make_datasets(self, datasets_dir="default", n=100):
if datasets_dir == "default":
@ -276,24 +297,23 @@ class Main:
self.camera.set(4, 720)
label = input("label:")
while not label == "":
data = np.zeros([n, 21, 3])
data = np.zeros([n, self.len_x, self.len_y])
shape_list = np.zeros([n, 2], dtype=np.int16)
hand_type = np.zeros(n, dtype=np.int8)
zero_data = np.zeros([21, 3])
count = 0
cv2.startWindowThread()
while True:
frame, img = self.camera.read()
img = self.detector.find_hands(img)
result = np.zeros((21, 3))
result = np.zeros((self.len_x, self.len_y))
lm_list, bbox = self.detector.find_position(img)
for i in range(len(lm_list)):
result[i] = np.array(lm_list[i])
shape = bbox["shape"]
if result.sum() > 0: # 假设矩阵不为0,即捕捉到手部时
if result.all() != zero_data.all(): # 假设矩阵不为0,即捕捉到手部时
shape = bbox["shape"]
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
data[count] = result
hand_type[count] = self.detector.hand_type()
@ -332,20 +352,26 @@ class Main:
self.detector = HandDetector()
cnn = torch.load("CNN.pkl")
out_label = cnn.out_label
result = []
disp = ""
while True:
frame, img = self.camera.read()
img = self.detector.find_hands(img)
lm_list, bbox = self.detector.find_position(img)
if lm_list:
if lm_list.any():
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
data = torch.Tensor(lm_list)
data = data.unsqueeze(0)
data = data.unsqueeze(0)
test_output = cnn(data)
result = torch.max(test_output, 1)[1].data.cpu().numpy()[0]
cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
result.append(torch.max(test_output, 1)[1].data.cpu().numpy()[0])
if len(result) > 5:
disp = str(out_label[stats.mode(result)[0][0]])
result = []
cv2.putText(img, disp, (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
(0, 0, 255), 3)
cv2.imshow("camera", img)
@ -359,6 +385,6 @@ class Main:
if __name__ == '__main__':
solution = Main()
my_datasets_dir = "test"
solution.make_datasets(my_datasets_dir, 50)
solution.make_datasets(my_datasets_dir, 100)
solution.train(my_datasets_dir)
solution.gesture_recognition()

Loading…
Cancel
Save