修改:增加众数判定
This commit is contained in:
parent
8bfe21e11b
commit
96c6039474
58
ai.py
58
ai.py
@ -14,6 +14,7 @@ import torch.nn as nn
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import shutil
|
import shutil
|
||||||
import math
|
import math
|
||||||
|
from scipy import stats
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
from os import mkdir
|
from os import mkdir
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -27,6 +28,13 @@ def rotate(angle, x, y, point_x, point_y):
|
|||||||
return px, py
|
return px, py
|
||||||
|
|
||||||
|
|
||||||
|
# 归一化
|
||||||
|
def normalize(x):
|
||||||
|
max_x = np.max(x)
|
||||||
|
min_x = np.min(x)
|
||||||
|
return (x-min_x)/(max_x-min_x)
|
||||||
|
|
||||||
|
|
||||||
class CNN(nn.Module):
|
class CNN(nn.Module):
|
||||||
def __init__(self, m):
|
def __init__(self, m):
|
||||||
super(CNN, self).__init__()
|
super(CNN, self).__init__()
|
||||||
@ -45,9 +53,9 @@ class CNN(nn.Module):
|
|||||||
self.conv2 = nn.Sequential(
|
self.conv2 = nn.Sequential(
|
||||||
nn.Conv2d(16, 32, 5, 1, 2),
|
nn.Conv2d(16, 32, 5, 1, 2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.MaxPool2d(3),
|
nn.MaxPool2d(2),
|
||||||
)
|
)
|
||||||
self.med = nn.Linear(32 * 7 * 1, 500)
|
self.med = nn.Linear(32 * 11 * 2, 500)
|
||||||
self.med2 = nn.Linear(1*21*3, 100)
|
self.med2 = nn.Linear(1*21*3, 100)
|
||||||
self.med3 = nn.Linear(100, 500)
|
self.med3 = nn.Linear(100, 500)
|
||||||
self.out = nn.Linear(500, m) # fully connected layer, output 10 classes
|
self.out = nn.Linear(500, m) # fully connected layer, output 10 classes
|
||||||
@ -131,7 +139,7 @@ class HandDetector:
|
|||||||
px, py = int(lm.x * w), int(lm.y * h)
|
px, py = int(lm.x * w), int(lm.y * h)
|
||||||
x_list.append(px)
|
x_list.append(px)
|
||||||
y_list.append(py)
|
y_list.append(py)
|
||||||
self.lmList.append([lm.x, lm.y, lm.z])
|
self.lmList.append([lm.x, lm.y, 0])
|
||||||
if draw:
|
if draw:
|
||||||
cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
|
cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
|
||||||
x_min, x_max = min(x_list), max(x_list)
|
x_min, x_max = min(x_list), max(x_list)
|
||||||
@ -147,6 +155,10 @@ class HandDetector:
|
|||||||
(0, 255, 0), 2)
|
(0, 255, 0), 2)
|
||||||
|
|
||||||
self.revolve(img)
|
self.revolve(img)
|
||||||
|
self.re_lmList = np.array(self.re_lmList)
|
||||||
|
if self.re_lmList.any():
|
||||||
|
self.re_lmList = np.concatenate((np.zeros((21, 1)), self.re_lmList), axis=1)
|
||||||
|
self.re_lmList = np.concatenate((self.re_lmList, np.zeros((1, 4))), axis=0)
|
||||||
|
|
||||||
return self.re_lmList, bbox_info
|
return self.re_lmList, bbox_info
|
||||||
|
|
||||||
@ -176,11 +188,18 @@ class HandDetector:
|
|||||||
theta = theta + math.pi
|
theta = theta + math.pi
|
||||||
# print(theta*180/math.pi)
|
# print(theta*180/math.pi)
|
||||||
for i in self.lmList:
|
for i in self.lmList:
|
||||||
x, y = rotate(theta, i[0], i[1], point_x, point_y)
|
px, py = rotate(theta, i[0] * w, i[1] * h, point_x * w, point_y * h)
|
||||||
px, py = int(x * w), int(y * h)
|
self.re_lmList.append([px, py, 0])
|
||||||
self.re_lmList.append([x, y, i[2]])
|
|
||||||
if draw:
|
if draw:
|
||||||
cv2.circle(img, (px, py), 5, (0, 0, 255), cv2.FILLED)
|
cv2.circle(img, (int(px), int(py)), 5, (0, 0, 255), cv2.FILLED)
|
||||||
|
# 归一化
|
||||||
|
x_array = normalize(np.array(self.re_lmList)[:, 0])
|
||||||
|
# print(x_array)
|
||||||
|
for i in range(len(x_array)):
|
||||||
|
self.re_lmList[i][0] = x_array[i]
|
||||||
|
y_array = normalize(np.array(self.re_lmList)[:, 1])
|
||||||
|
for i in range(len(y_array)):
|
||||||
|
self.re_lmList[i][1] = x_array[i]
|
||||||
else:
|
else:
|
||||||
self.re_lmList = self.lmList
|
self.re_lmList = self.lmList
|
||||||
return self.re_lmList
|
return self.re_lmList
|
||||||
@ -263,6 +282,8 @@ class Main:
|
|||||||
self.camera = None
|
self.camera = None
|
||||||
self.detector = HandDetector()
|
self.detector = HandDetector()
|
||||||
self.default_datasets = "Datasets"
|
self.default_datasets = "Datasets"
|
||||||
|
self.len_x = 22
|
||||||
|
self.len_y = 4
|
||||||
|
|
||||||
def make_datasets(self, datasets_dir="default", n=100):
|
def make_datasets(self, datasets_dir="default", n=100):
|
||||||
if datasets_dir == "default":
|
if datasets_dir == "default":
|
||||||
@ -276,24 +297,23 @@ class Main:
|
|||||||
self.camera.set(4, 720)
|
self.camera.set(4, 720)
|
||||||
label = input("label:")
|
label = input("label:")
|
||||||
while not label == "":
|
while not label == "":
|
||||||
data = np.zeros([n, 21, 3])
|
data = np.zeros([n, self.len_x, self.len_y])
|
||||||
shape_list = np.zeros([n, 2], dtype=np.int16)
|
shape_list = np.zeros([n, 2], dtype=np.int16)
|
||||||
hand_type = np.zeros(n, dtype=np.int8)
|
hand_type = np.zeros(n, dtype=np.int8)
|
||||||
|
|
||||||
zero_data = np.zeros([21, 3])
|
|
||||||
count = 0
|
count = 0
|
||||||
cv2.startWindowThread()
|
cv2.startWindowThread()
|
||||||
while True:
|
while True:
|
||||||
frame, img = self.camera.read()
|
frame, img = self.camera.read()
|
||||||
img = self.detector.find_hands(img)
|
img = self.detector.find_hands(img)
|
||||||
result = np.zeros((21, 3))
|
result = np.zeros((self.len_x, self.len_y))
|
||||||
|
|
||||||
lm_list, bbox = self.detector.find_position(img)
|
lm_list, bbox = self.detector.find_position(img)
|
||||||
for i in range(len(lm_list)):
|
for i in range(len(lm_list)):
|
||||||
result[i] = np.array(lm_list[i])
|
result[i] = np.array(lm_list[i])
|
||||||
shape = bbox["shape"]
|
if result.sum() > 0: # 假设矩阵不为0,即捕捉到手部时
|
||||||
|
|
||||||
if result.all() != zero_data.all(): # 假设矩阵不为0,即捕捉到手部时
|
shape = bbox["shape"]
|
||||||
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
||||||
data[count] = result
|
data[count] = result
|
||||||
hand_type[count] = self.detector.hand_type()
|
hand_type[count] = self.detector.hand_type()
|
||||||
@ -332,20 +352,26 @@ class Main:
|
|||||||
self.detector = HandDetector()
|
self.detector = HandDetector()
|
||||||
cnn = torch.load("CNN.pkl")
|
cnn = torch.load("CNN.pkl")
|
||||||
out_label = cnn.out_label
|
out_label = cnn.out_label
|
||||||
|
result = []
|
||||||
|
disp = ""
|
||||||
while True:
|
while True:
|
||||||
frame, img = self.camera.read()
|
frame, img = self.camera.read()
|
||||||
img = self.detector.find_hands(img)
|
img = self.detector.find_hands(img)
|
||||||
lm_list, bbox = self.detector.find_position(img)
|
lm_list, bbox = self.detector.find_position(img)
|
||||||
|
|
||||||
if lm_list:
|
if lm_list.any():
|
||||||
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
||||||
data = torch.Tensor(lm_list)
|
data = torch.Tensor(lm_list)
|
||||||
data = data.unsqueeze(0)
|
data = data.unsqueeze(0)
|
||||||
data = data.unsqueeze(0)
|
data = data.unsqueeze(0)
|
||||||
|
|
||||||
test_output = cnn(data)
|
test_output = cnn(data)
|
||||||
result = torch.max(test_output, 1)[1].data.cpu().numpy()[0]
|
result.append(torch.max(test_output, 1)[1].data.cpu().numpy()[0])
|
||||||
cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
if len(result) > 5:
|
||||||
|
disp = str(out_label[stats.mode(result)[0][0]])
|
||||||
|
result = []
|
||||||
|
|
||||||
|
cv2.putText(img, disp, (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||||
(0, 0, 255), 3)
|
(0, 0, 255), 3)
|
||||||
|
|
||||||
cv2.imshow("camera", img)
|
cv2.imshow("camera", img)
|
||||||
@ -359,6 +385,6 @@ class Main:
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
solution = Main()
|
solution = Main()
|
||||||
my_datasets_dir = "test"
|
my_datasets_dir = "test"
|
||||||
solution.make_datasets(my_datasets_dir, 50)
|
solution.make_datasets(my_datasets_dir, 100)
|
||||||
solution.train(my_datasets_dir)
|
solution.train(my_datasets_dir)
|
||||||
solution.gesture_recognition()
|
solution.gesture_recognition()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user