增加:CNN网络测试(基本流程已完成)
附:训练结果
This commit is contained in:
parent
ea6ee46971
commit
5f9d6f5abc
68
demo.py
68
demo.py
@ -20,21 +20,21 @@ class CNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(CNN, self).__init__()
|
||||
self.out_label = []
|
||||
self.conv1 = nn.Sequential( # input shape (1, 21, 3)
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1, # input height
|
||||
out_channels=16, # n_filters
|
||||
kernel_size=5, # filter size
|
||||
stride=1, # filter movement/step
|
||||
padding=2, # 如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1
|
||||
), # output shape (16, 28, 28)
|
||||
nn.ReLU(), # activation
|
||||
nn.MaxPool2d(kernel_size=1), # 在 2x2 空间里向下采样, output shape (16, 14, 14)
|
||||
in_channels=1,
|
||||
out_channels=16,
|
||||
kernel_size=5,
|
||||
stride=1,
|
||||
padding=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=1),
|
||||
)
|
||||
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
|
||||
nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14)
|
||||
nn.ReLU(), # activation
|
||||
nn.MaxPool2d(3), # output shape (32, 7, 7)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(16, 32, 5, 1, 2),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(3),
|
||||
)
|
||||
self.med = nn.Linear(32 * 7 * 1, 500)
|
||||
self.out = nn.Linear(500, 10) # fully connected layer, output 10 classes
|
||||
@ -109,14 +109,14 @@ class HandDetector:
|
||||
y_list = []
|
||||
bbox_info = []
|
||||
self.lmList = []
|
||||
h, w, c = img.shape
|
||||
if self.results.multi_hand_landmarks:
|
||||
my_hand = self.results.multi_hand_landmarks[hand_no]
|
||||
for _, lm in enumerate(my_hand.landmark):
|
||||
h, w, c = img.shape
|
||||
px, py = int(lm.x * w), int(lm.y * h)
|
||||
x_list.append(px)
|
||||
y_list.append(py)
|
||||
self.lmList.append([px, py])
|
||||
self.lmList.append([lm.x, lm.y, lm.z])
|
||||
if draw:
|
||||
cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
|
||||
x_min, x_max = min(x_list), max(x_list)
|
||||
@ -174,8 +174,8 @@ class HandDetector:
|
||||
|
||||
class Main:
|
||||
def __init__(self):
|
||||
self.EPOCH = 20
|
||||
self.BATCH_SIZE = 10
|
||||
self.EPOCH = 50
|
||||
self.BATCH_SIZE = 5
|
||||
self.LR = 10e-5
|
||||
self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@ -238,6 +238,8 @@ class Main:
|
||||
|
||||
def gesture_recognition(self):
|
||||
self.detector = HandDetector()
|
||||
cnn = torch.load("CNN.pkl")
|
||||
out_label = cnn.out_label
|
||||
while True:
|
||||
frame, img = self.camera.read()
|
||||
img = self.detector.find_hands(img)
|
||||
@ -245,31 +247,13 @@ class Main:
|
||||
|
||||
if lm_list:
|
||||
x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
|
||||
x1, x2, x3, x4, x5 = self.detector.fingers_up()
|
||||
data = torch.Tensor(lm_list)
|
||||
data = data.unsqueeze(0)
|
||||
data = data.unsqueeze(0)
|
||||
|
||||
if (x2 == 1 and x3 == 1) and (x4 == 0 and x5 == 0 and x1 == 0):
|
||||
cv2.putText(img, "2_TWO", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
elif (x2 == 1 and x3 == 1 and x4 == 1) and (x1 == 0 and x5 == 0):
|
||||
cv2.putText(img, "3_THREE", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
elif (x2 == 1 and x3 == 1 and x4 == 1 and x5 == 1) and (x1 == 0):
|
||||
cv2.putText(img, "4_FOUR", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
elif x1 == 1 and x2 == 1 and x3 == 1 and x4 == 1 and x5 == 1:
|
||||
cv2.putText(img, "5_FIVE", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
elif x2 == 1 and x1 == 0 and (x3 == 0, x4 == 0, x5 == 0):
|
||||
cv2.putText(img, "1_ONE", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
elif x1 == 1 and x2 == 1 and (x3 == 0, x4 == 0, x5 == 0):
|
||||
cv2.putText(img, "8_EIGHT", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
elif x1 == 1 and x5 == 1 and (x3 == 0, x4 == 0, x5 == 0):
|
||||
cv2.putText(img, "6_SIX", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
elif x1 and (x2 == 0, x3 == 0, x4 == 0, x5 == 0):
|
||||
cv2.putText(img, "GOOD!", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
test_output = cnn(data)
|
||||
result = torch.max(test_output, 1)[1].data.cpu().numpy()[0]
|
||||
cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
|
||||
(0, 0, 255), 3)
|
||||
|
||||
cv2.imshow("camera", img)
|
||||
@ -282,6 +266,6 @@ class Main:
|
||||
|
||||
if __name__ == '__main__':
|
||||
Solution = Main()
|
||||
# Solution.gesture_recognition()
|
||||
Solution.load_datasets()
|
||||
Solution.train_cnn()
|
||||
Solution.gesture_recognition()
|
||||
|
Loading…
x
Reference in New Issue
Block a user