增加:用户自定义手势
This commit is contained in:
parent
5f9d6f5abc
commit
ff20ecd6ff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Datasets/one.npz
BIN
Datasets/one.npz
Binary file not shown.
Binary file not shown.
BIN
Datasets/six.npz
BIN
Datasets/six.npz
Binary file not shown.
Binary file not shown.
BIN
Datasets/two.npz
BIN
Datasets/two.npz
Binary file not shown.
Binary file not shown.
158
demo.py
158
demo.py
@ -12,12 +12,15 @@ import mediapipe as mp
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import shutil
|
||||||
|
from os.path import exists
|
||||||
|
from os import mkdir
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
|
|
||||||
class CNN(nn.Module):
|
class CNN(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, m):
|
||||||
super(CNN, self).__init__()
|
super(CNN, self).__init__()
|
||||||
self.out_label = []
|
self.out_label = []
|
||||||
self.conv1 = nn.Sequential(
|
self.conv1 = nn.Sequential(
|
||||||
@ -37,13 +40,17 @@ class CNN(nn.Module):
|
|||||||
nn.MaxPool2d(3),
|
nn.MaxPool2d(3),
|
||||||
)
|
)
|
||||||
self.med = nn.Linear(32 * 7 * 1, 500)
|
self.med = nn.Linear(32 * 7 * 1, 500)
|
||||||
self.out = nn.Linear(500, 10) # fully connected layer, output 10 classes
|
self.med2 = nn.Linear(1*21*3, 100)
|
||||||
|
self.med3 = nn.Linear(100, 500)
|
||||||
|
self.out = nn.Linear(500, m) # fully connected layer, output 10 classes
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = x.view(x.size(0), -1) # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
|
x = x.view(x.size(0), -1) # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
|
||||||
x = self.med(x)
|
x = self.med(x)
|
||||||
|
# x = self.med2(x)
|
||||||
|
# x = self.med3(x)
|
||||||
output = self.out(x)
|
output = self.out(x)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -104,19 +111,20 @@ class HandDetector:
|
|||||||
:param draw: 在图像上绘制输出的标志。(默认绘制矩形框)
|
:param draw: 在图像上绘制输出的标志。(默认绘制矩形框)
|
||||||
:return: 像素格式的手部关节位置列表;手部边界框
|
:return: 像素格式的手部关节位置列表;手部边界框
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x_list = []
|
x_list = []
|
||||||
y_list = []
|
y_list = []
|
||||||
|
one_data = np.zeros([21, 3])
|
||||||
bbox_info = []
|
bbox_info = []
|
||||||
self.lmList = []
|
self.lmList = []
|
||||||
h, w, c = img.shape
|
h, w, c = img.shape
|
||||||
if self.results.multi_hand_landmarks:
|
if self.results.multi_hand_landmarks:
|
||||||
my_hand = self.results.multi_hand_landmarks[hand_no]
|
my_hand = self.results.multi_hand_landmarks[hand_no]
|
||||||
for _, lm in enumerate(my_hand.landmark):
|
for i, lm in enumerate(my_hand.landmark):
|
||||||
px, py = int(lm.x * w), int(lm.y * h)
|
px, py = int(lm.x * w), int(lm.y * h)
|
||||||
x_list.append(px)
|
x_list.append(px)
|
||||||
y_list.append(py)
|
y_list.append(py)
|
||||||
self.lmList.append([lm.x, lm.y, lm.z])
|
self.lmList.append([lm.x, lm.y, lm.z])
|
||||||
|
one_data[i] = np.array([lm.x, lm.y, lm.z])
|
||||||
if draw:
|
if draw:
|
||||||
cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
|
cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
|
||||||
x_min, x_max = min(x_list), max(x_list)
|
x_min, x_max = min(x_list), max(x_list)
|
||||||
@ -131,83 +139,52 @@ class HandDetector:
|
|||||||
(bbox[0] + bbox[2] + 20, bbox[1] + bbox[3] + 20),
|
(bbox[0] + bbox[2] + 20, bbox[1] + bbox[3] + 20),
|
||||||
(0, 255, 0), 2)
|
(0, 255, 0), 2)
|
||||||
|
|
||||||
return self.lmList, bbox_info
|
return one_data, (h, w), self.lmList, bbox_info
|
||||||
|
|
||||||
def fingers_up(self):
|
|
||||||
"""
|
|
||||||
查找列表中打开并返回的手指数。会分别考虑左手和右手
|
|
||||||
:return: 竖起手指的列表
|
|
||||||
"""
|
|
||||||
fingers = []
|
|
||||||
if self.results.multi_hand_landmarks:
|
|
||||||
my_hand_type = self.hand_type()
|
|
||||||
# Thumb
|
|
||||||
if my_hand_type == "Right":
|
|
||||||
if self.lmList[self.tipIds[0]][0] > self.lmList[self.tipIds[0] - 1][0]:
|
|
||||||
fingers.append(1)
|
|
||||||
else:
|
|
||||||
fingers.append(0)
|
|
||||||
else:
|
|
||||||
if self.lmList[self.tipIds[0]][0] < self.lmList[self.tipIds[0] - 1][0]:
|
|
||||||
fingers.append(1)
|
|
||||||
else:
|
|
||||||
fingers.append(0)
|
|
||||||
# 4 Fingers
|
|
||||||
for i in range(1, 5):
|
|
||||||
if self.lmList[self.tipIds[i]][1] < self.lmList[self.tipIds[i] - 2][1]:
|
|
||||||
fingers.append(1)
|
|
||||||
else:
|
|
||||||
fingers.append(0)
|
|
||||||
return fingers
|
|
||||||
|
|
||||||
def hand_type(self):
|
def hand_type(self):
|
||||||
"""
|
"""
|
||||||
检查传入的手部是左还是右
|
检查传入的手部 是左还是右
|
||||||
:return: "Right" 或 "Left"
|
:return: 1 或 0
|
||||||
"""
|
"""
|
||||||
if self.results.multi_hand_landmarks:
|
if self.results.multi_hand_landmarks:
|
||||||
if self.lmList[17][0] < self.lmList[5][0]:
|
if self.lmList[17][0] < self.lmList[5][0]:
|
||||||
return "Right"
|
return 1
|
||||||
else:
|
else:
|
||||||
return "Left"
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class Main:
|
class AI:
|
||||||
def __init__(self):
|
def __init__(self, datasets_dir):
|
||||||
self.EPOCH = 50
|
self.EPOCH = 20
|
||||||
self.BATCH_SIZE = 5
|
self.BATCH_SIZE = 2
|
||||||
self.LR = 10e-5
|
self.LR = 10e-5
|
||||||
self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.datasets_dir = datasets_dir
|
||||||
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
|
||||||
self.camera.set(3, 1280)
|
|
||||||
self.camera.set(4, 720)
|
|
||||||
|
|
||||||
self.datasets_dir = "Datasets"
|
|
||||||
self.train_loader = None
|
self.train_loader = None
|
||||||
|
self.m = 0
|
||||||
self.out_label = [] # CNN网络输出后数字标签转和字符串标签的映射关系
|
self.out_label = [] # CNN网络输出后数字标签转和字符串标签的映射关系
|
||||||
|
|
||||||
self.detector = None
|
|
||||||
|
|
||||||
def load_datasets(self):
|
def load_datasets(self):
|
||||||
train_data = []
|
train_data = []
|
||||||
train_label = []
|
train_label = []
|
||||||
|
self.m = 0
|
||||||
for file in Path(self.datasets_dir).rglob("*.npz"):
|
for file in Path(self.datasets_dir).rglob("*.npz"):
|
||||||
data = np.load(str(file))
|
data = np.load(str(file))
|
||||||
train_data.append(data["data"])
|
train_data.append(data["data"])
|
||||||
label_number = np.ones(len(data["data"]))*len(self.out_label)
|
label_number = np.ones(len(data["data"])) * len(self.out_label)
|
||||||
train_label.append(label_number)
|
train_label.append(label_number)
|
||||||
self.out_label.append(data["label"])
|
self.out_label.append(data["label"])
|
||||||
|
self.m += 1
|
||||||
train_data = torch.Tensor(np.concatenate(train_data, axis=0))
|
train_data = torch.Tensor(np.concatenate(train_data, axis=0))
|
||||||
train_data = train_data.unsqueeze(1)
|
train_data = train_data.unsqueeze(1)
|
||||||
train_label = torch.tensor(np.concatenate(train_label, axis=0)).long()
|
train_label = torch.tensor(np.concatenate(train_label, axis=0)).long()
|
||||||
|
|
||||||
dataset = TensorDataset(train_data, train_label)
|
dataset = TensorDataset(train_data, train_label)
|
||||||
self.train_loader = DataLoader(dataset, batch_size=self.BATCH_SIZE, shuffle=True)
|
self.train_loader = DataLoader(dataset, batch_size=self.BATCH_SIZE, shuffle=True)
|
||||||
|
return self.m
|
||||||
|
|
||||||
def train_cnn(self):
|
def train_cnn(self):
|
||||||
cnn = CNN().to(self.DEVICE)
|
cnn = CNN(self.m).to(self.DEVICE)
|
||||||
optimizer = torch.optim.Adam(cnn.parameters(), self.LR) # optimize all cnn parameters
|
optimizer = torch.optim.Adam(cnn.parameters(), self.LR) # optimize all cnn parameters
|
||||||
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
|
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
|
||||||
|
|
||||||
@ -236,14 +213,80 @@ class Main:
|
|||||||
torch.save(cnn, 'CNN.pkl')
|
torch.save(cnn, 'CNN.pkl')
|
||||||
print("训练结束")
|
print("训练结束")
|
||||||
|
|
||||||
|
|
||||||
|
class Main:
|
||||||
|
def __init__(self):
|
||||||
|
self.camera = None
|
||||||
|
self.detector = HandDetector()
|
||||||
|
self.default_datasets = "Datasets"
|
||||||
|
|
||||||
|
def make_datasets(self, datasets_dir="default", n=100):
|
||||||
|
if datasets_dir == "default":
|
||||||
|
return
|
||||||
|
if exists(datasets_dir):
|
||||||
|
shutil.rmtree(datasets_dir)
|
||||||
|
mkdir(datasets_dir)
|
||||||
|
if self.camera is None:
|
||||||
|
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
||||||
|
self.camera.set(3, 1280)
|
||||||
|
self.camera.set(4, 720)
|
||||||
|
label = input("label:")
|
||||||
|
while not label == "":
|
||||||
|
data = np.zeros([n, 21, 3])
|
||||||
|
shape_list = np.zeros([n, 2], dtype=np.int16)
|
||||||
|
hand_type = np.zeros(n, dtype=np.int8)
|
||||||
|
|
||||||
|
zero_data = np.zeros([21, 3])
|
||||||
|
count = 0
|
||||||
|
cv2.startWindowThread()
|
||||||
|
while True:
|
||||||
|
frame, img = self.camera.read()
|
||||||
|
img = self.detector.find_hands(img)
|
||||||
|
|
||||||
|
result, shape, _, bbox = self.detector.find_position(img)
|
||||||
|
if result.all() != zero_data.all(): # 假设矩阵不为0,即捕捉到手部时
|
||||||
|
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
||||||
|
data[count] = result
|
||||||
|
hand_type[count] = self.detector.hand_type()
|
||||||
|
shape_list[count] = np.array(shape)
|
||||||
|
count += 1
|
||||||
|
cv2.putText(img, str("{}/{}".format(count, n)), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||||
|
(0, 255, 0), 3)
|
||||||
|
|
||||||
|
cv2.imshow("camera", img)
|
||||||
|
key = cv2.waitKey(100)
|
||||||
|
if cv2.getWindowProperty('camera', cv2.WND_PROP_VISIBLE) < 1:
|
||||||
|
break
|
||||||
|
elif key == 27:
|
||||||
|
break
|
||||||
|
elif count == n - 1:
|
||||||
|
break
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
open(datasets_dir + "/" + label + ".npz", "w")
|
||||||
|
np.savez(datasets_dir + "/" + label + ".npz", label=label, data=data,
|
||||||
|
handtype=hand_type, shape=shape_list)
|
||||||
|
label = input("label:")
|
||||||
|
|
||||||
|
def train(self, datasets_dir="default"):
|
||||||
|
if datasets_dir == "default":
|
||||||
|
datasets_dir = self.default_datasets
|
||||||
|
ai = AI(datasets_dir)
|
||||||
|
ai.load_datasets()
|
||||||
|
ai.train_cnn()
|
||||||
|
|
||||||
def gesture_recognition(self):
|
def gesture_recognition(self):
|
||||||
|
if self.camera is None:
|
||||||
|
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
||||||
|
self.camera.set(3, 1280)
|
||||||
|
self.camera.set(4, 720)
|
||||||
self.detector = HandDetector()
|
self.detector = HandDetector()
|
||||||
cnn = torch.load("CNN.pkl")
|
cnn = torch.load("CNN.pkl")
|
||||||
out_label = cnn.out_label
|
out_label = cnn.out_label
|
||||||
while True:
|
while True:
|
||||||
frame, img = self.camera.read()
|
frame, img = self.camera.read()
|
||||||
img = self.detector.find_hands(img)
|
img = self.detector.find_hands(img)
|
||||||
lm_list, bbox = self.detector.find_position(img)
|
_, _, lm_list, bbox = self.detector.find_position(img)
|
||||||
|
|
||||||
if lm_list:
|
if lm_list:
|
||||||
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
||||||
@ -254,7 +297,7 @@ class Main:
|
|||||||
test_output = cnn(data)
|
test_output = cnn(data)
|
||||||
result = torch.max(test_output, 1)[1].data.cpu().numpy()[0]
|
result = torch.max(test_output, 1)[1].data.cpu().numpy()[0]
|
||||||
cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||||
(0, 0, 255), 3)
|
(0, 0, 255), 3)
|
||||||
|
|
||||||
cv2.imshow("camera", img)
|
cv2.imshow("camera", img)
|
||||||
key = cv2.waitKey(1)
|
key = cv2.waitKey(1)
|
||||||
@ -265,7 +308,8 @@ class Main:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
Solution = Main()
|
solution = Main()
|
||||||
Solution.load_datasets()
|
my_datasets_dir = "test"
|
||||||
Solution.train_cnn()
|
solution.make_datasets(my_datasets_dir, 200)
|
||||||
Solution.gesture_recognition()
|
solution.train(my_datasets_dir)
|
||||||
|
solution.gesture_recognition()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user