diff --git a/CNN.pkl b/CNN.pkl
index ff3a1d4..20d36c5 100644
Binary files a/CNN.pkl and b/CNN.pkl differ
diff --git a/Datasets/eight.npz b/Datasets/eight.npz
index da3d73f..5ef32f8 100644
Binary files a/Datasets/eight.npz and b/Datasets/eight.npz differ
diff --git a/Datasets/five.npz b/Datasets/five.npz
index 99e85d0..6c43052 100644
Binary files a/Datasets/five.npz and b/Datasets/five.npz differ
diff --git a/Datasets/four.npz b/Datasets/four.npz
index d58ccfe..e4e23cf 100644
Binary files a/Datasets/four.npz and b/Datasets/four.npz differ
diff --git a/Datasets/nine.npz b/Datasets/nine.npz
index 978dd07..3e00d80 100644
Binary files a/Datasets/nine.npz and b/Datasets/nine.npz differ
diff --git a/Datasets/one.npz b/Datasets/one.npz
index 1e3a850..a7d14ed 100644
Binary files a/Datasets/one.npz and b/Datasets/one.npz differ
diff --git a/Datasets/seven.npz b/Datasets/seven.npz
index 394e5e5..ec45515 100644
Binary files a/Datasets/seven.npz and b/Datasets/seven.npz differ
diff --git a/Datasets/six.npz b/Datasets/six.npz
index f16f347..3f88257 100644
Binary files a/Datasets/six.npz and b/Datasets/six.npz differ
diff --git a/Datasets/three.npz b/Datasets/three.npz
index 3c6ce91..24bc764 100644
Binary files a/Datasets/three.npz and b/Datasets/three.npz differ
diff --git a/Datasets/two.npz b/Datasets/two.npz
index 620c5fa..816829e 100644
Binary files a/Datasets/two.npz and b/Datasets/two.npz differ
diff --git a/Datasets/zero.npz b/Datasets/zero.npz
index 0f2054f..0c48068 100644
Binary files a/Datasets/zero.npz and b/Datasets/zero.npz differ
diff --git a/demo.py b/demo.py
index 974e1f8..a1e5bd2 100644
--- a/demo.py
+++ b/demo.py
@@ -12,12 +12,15 @@ import mediapipe as mp
 import torch
 import torch.nn as nn
 import numpy as np
+import shutil
+from os.path import exists
+from os import mkdir
 from pathlib import Path
 from torch.utils.data import DataLoader, TensorDataset


 class CNN(nn.Module):
-    def __init__(self):
+    def __init__(self, m):
         super(CNN, self).__init__()
         self.out_label = []
         self.conv1 = nn.Sequential(
@@ -37,13 +40,17 @@ class CNN(nn.Module):
             nn.MaxPool2d(3),
         )
         self.med = nn.Linear(32 * 7 * 1, 500)
-        self.out = nn.Linear(500, 10)  # fully connected layer, output 10 classes
+        self.med2 = nn.Linear(1 * 21 * 3, 100)
+        self.med3 = nn.Linear(100, 500)
+        self.out = nn.Linear(500, m)  # fully connected layer, output m classes

     def forward(self, x):
         x = self.conv1(x)
         x = self.conv2(x)
         x = x.view(x.size(0), -1)  # flatten the conv feature maps to (batch_size, 32 * 7 * 1)
         x = self.med(x)
+        # x = self.med2(x)
+        # x = self.med3(x)
         output = self.out(x)
         return output

@@ -104,19 +111,20 @@ class HandDetector:
         :param draw: flag for drawing the output on the image (draws the bounding box by default)
-        :return: list of hand landmark positions in pixel format; hand bounding box
+        :return: 21x3 landmark array; image shape (h, w); landmark list; bounding box info
         """
-
         x_list = []
         y_list = []
+        one_data = np.zeros([21, 3])
         bbox_info = []
         self.lmList = []
         h, w, c = img.shape
         if self.results.multi_hand_landmarks:
             my_hand = self.results.multi_hand_landmarks[hand_no]
-            for _, lm in enumerate(my_hand.landmark):
+            for i, lm in enumerate(my_hand.landmark):
                 px, py = int(lm.x * w), int(lm.y * h)
                 x_list.append(px)
                 y_list.append(py)
                 self.lmList.append([lm.x, lm.y, lm.z])
+                one_data[i] = np.array([lm.x, lm.y, lm.z])
                 if draw:
                     cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
             x_min, x_max = min(x_list), max(x_list)
@@ -131,83 +139,52 @@ class HandDetector:
                               (bbox[0] + bbox[2] + 20, bbox[1] + bbox[3] + 20),
                               (0, 255, 0), 2)

-        return self.lmList, bbox_info
-
-    def fingers_up(self):
-        """
-        Find the number of fingers that are up and return them in a list. Left and right hands are considered separately.
-        :return: list of raised fingers
-        """
-        fingers = []
-        if self.results.multi_hand_landmarks:
-            my_hand_type = self.hand_type()
-            # Thumb
-            if my_hand_type == "Right":
-                if self.lmList[self.tipIds[0]][0] > self.lmList[self.tipIds[0] - 1][0]:
-                    fingers.append(1)
-                else:
-                    fingers.append(0)
-            else:
-                if self.lmList[self.tipIds[0]][0] < self.lmList[self.tipIds[0] - 1][0]:
-                    fingers.append(1)
-                else:
-                    fingers.append(0)
-            # 4 Fingers
-            for i in range(1, 5):
-                if self.lmList[self.tipIds[i]][1] < self.lmList[self.tipIds[i] - 2][1]:
-                    fingers.append(1)
-                else:
-                    fingers.append(0)
-        return fingers
+        return one_data, (h, w), self.lmList, bbox_info

     def hand_type(self):
         """
         Check whether the detected hand is a left or a right hand
-        :return: "Right" or "Left"
+        :return: 1 (right) or 0 (left)
         """
         if self.results.multi_hand_landmarks:
             if self.lmList[17][0] < self.lmList[5][0]:
-                return "Right"
+                return 1
             else:
-                return "Left"
+                return 0


-class Main:
-    def __init__(self):
-        self.EPOCH = 50
-        self.BATCH_SIZE = 5
+class AI:
+    def __init__(self, datasets_dir):
+        self.EPOCH = 20
+        self.BATCH_SIZE = 2
         self.LR = 10e-5
         self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
-        self.camera.set(3, 1280)
-        self.camera.set(4, 720)
-
-        self.datasets_dir = "Datasets"
+        self.datasets_dir = datasets_dir
         self.train_loader = None
+        self.m = 0
        self.out_label = []  # maps the CNN's numeric output labels back to their string labels

-        self.detector = None
-
     def load_datasets(self):
         train_data = []
         train_label = []
-
+        self.m = 0
         for file in Path(self.datasets_dir).rglob("*.npz"):
             data = np.load(str(file))
             train_data.append(data["data"])
-            label_number = np.ones(len(data["data"]))*len(self.out_label)
+            label_number = np.ones(len(data["data"])) * len(self.out_label)
             train_label.append(label_number)
             self.out_label.append(data["label"])
+            self.m += 1
         train_data = torch.Tensor(np.concatenate(train_data, axis=0))
         train_data = train_data.unsqueeze(1)
         train_label = torch.tensor(np.concatenate(train_label, axis=0)).long()
         dataset = TensorDataset(train_data, train_label)
         self.train_loader = DataLoader(dataset, batch_size=self.BATCH_SIZE, shuffle=True)
+        return self.m

     def train_cnn(self):
-        cnn = CNN().to(self.DEVICE)
+        cnn = CNN(self.m).to(self.DEVICE)
         optimizer = torch.optim.Adam(cnn.parameters(), self.LR)  # optimize all cnn parameters
         loss_func = nn.CrossEntropyLoss()  # the target label is not one-hotted
@@ -236,14 +213,80 @@ class Main:
         torch.save(cnn, 'CNN.pkl')
         print("Training finished")

+
+class Main:
+    def __init__(self):
+        self.camera = None
+        self.detector = HandDetector()
+        self.default_datasets = "Datasets"
+
+    def make_datasets(self, datasets_dir="default", n=100):
+        if datasets_dir == "default":
+            return
+        if exists(datasets_dir):
+            shutil.rmtree(datasets_dir)
+        mkdir(datasets_dir)
+        if self.camera is None:
+            self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+            self.camera.set(3, 1280)
+            self.camera.set(4, 720)
+        label = input("label:")
+        while label != "":
+            data = np.zeros([n, 21, 3])
+            shape_list = np.zeros([n, 2], dtype=np.int16)
+            hand_type = np.zeros(n, dtype=np.int8)
+
+            zero_data = np.zeros([21, 3])
+            count = 0
+            cv2.startWindowThread()
+            while True:
+                frame, img = self.camera.read()
+                img = self.detector.find_hands(img)
+
+                result, shape, _, bbox = self.detector.find_position(img)
+                if not np.array_equal(result, zero_data):  # landmarks are non-zero, i.e. a hand was captured
+                    x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
+                    data[count] = result
+                    hand_type[count] = self.detector.hand_type()
+                    shape_list[count] = np.array(shape)
+                    count += 1
+                    cv2.putText(img, str("{}/{}".format(count, n)), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
+                                (0, 255, 0), 3)
+
+                cv2.imshow("camera", img)
+                key = cv2.waitKey(100)
+                if cv2.getWindowProperty('camera', cv2.WND_PROP_VISIBLE) < 1:
+                    break
+                elif key == 27:
+                    break
+                elif count == n:  # buffer full: n samples collected
+                    break
+            cv2.destroyAllWindows()
+
+            np.savez(datasets_dir + "/" + label + ".npz", label=label, data=data,
+                     handtype=hand_type, shape=shape_list)
+            label = input("label:")
+
+    def train(self, datasets_dir="default"):
+        if datasets_dir == "default":
+            datasets_dir = self.default_datasets
+        ai = AI(datasets_dir)
+        ai.load_datasets()
+        ai.train_cnn()
+
     def gesture_recognition(self):
+        if self.camera is None:
+            self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+            self.camera.set(3, 1280)
+            self.camera.set(4, 720)
         self.detector = HandDetector()
         cnn = torch.load("CNN.pkl")
         out_label = cnn.out_label
         while True:
             frame, img = self.camera.read()
             img = self.detector.find_hands(img)
-            lm_list, bbox = self.detector.find_position(img)
+            _, _, lm_list, bbox = self.detector.find_position(img)

             if lm_list:
                 x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
@@ -254,7 +297,7 @@ class Main:
                 test_output = cnn(data)
                 result = torch.max(test_output, 1)[1].data.cpu().numpy()[0]
                 cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                           (0, 0, 255), 3)
+                            (0, 0, 255), 3)

             cv2.imshow("camera", img)
             key = cv2.waitKey(1)
@@ -265,7 +308,8 @@ class Main:


 if __name__ == '__main__':
-    Solution = Main()
-    Solution.load_datasets()
-    Solution.train_cnn()
-    Solution.gesture_recognition()
+    solution = Main()
+    my_datasets_dir = "test"
+    solution.make_datasets(my_datasets_dir, 200)
+    solution.train(my_datasets_dir)
+    solution.gesture_recognition()
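Note on the new dataset format: `make_datasets` writes one `.npz` archive per label with the keys `label`, `data`, `handtype`, and `shape`, and `AI.load_datasets` later infers the class count `m` and the index-to-label mapping from however many of these archives it finds under the dataset directory. Below is a minimal sketch for sanity-checking a captured archive before training; the path `test/fist.npz` is a hypothetical example, so substitute whatever label you actually recorded:

```python
# Inspect one archive written by make_datasets(); "test/fist.npz" is a
# hypothetical example path -- use any label you actually recorded.
import numpy as np

archive = np.load("test/fist.npz")
print(archive["label"])           # string label typed at the "label:" prompt
print(archive["data"].shape)      # (n, 21, 3): n frames of 21 (x, y, z) landmarks
print(archive["handtype"].shape)  # (n,): 1 = right hand, 0 = left hand
print(archive["shape"].shape)     # (n, 2): source image (height, width) per frame
```

Because `load_datasets` assigns class indices in `Path.rglob` iteration order, the list stored on the model as `cnn.out_label` is what maps a predicted index back to its label string, which is why `gesture_recognition` reads it from the loaded `CNN.pkl` rather than re-scanning the dataset directory.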