Browse Source

feat: can count squat rep

Léo Salé 3 năm trước cách đây
mục cha
commit
b6c426536e

+ 87 - 67
app/lib/exercises/exercises_validation/widgets/pose_detector.dart

@@ -1,21 +1,26 @@
 import 'dart:async';
-import 'dart:io';
 
-import 'package:body_detection/models/point3d.dart';
 import 'package:body_detection/models/pose_landmark.dart';
 import 'package:body_detection/models/pose_landmark_type.dart';
+import 'package:physigo/navigation/utils/geometry_utils.dart';
 import 'package:rxdart/rxdart.dart';
 import 'package:body_detection/body_detection.dart';
 import 'package:body_detection/models/image_result.dart';
 import 'package:body_detection/models/pose.dart';
 import 'package:flutter/material.dart';
-import 'package:path_provider/path_provider.dart';
 
 import 'pose_painter.dart';
 
 typedef MeanFilteredData = Iterable<List<double>>;
 typedef LandmarkVariations = List<List<double>>;
 
+enum StepExercise {
+  notInPlace,
+  ready,
+  start,
+  end,
+}
+
 class PoseDetector extends StatefulWidget {
   const PoseDetector({Key? key}) : super(key: key);
 
@@ -25,51 +30,69 @@ class PoseDetector extends StatefulWidget {
 
 class _PoseDetectorState extends State<PoseDetector> {
   static const buffer = 10;
-  static const _shouldWriteToFile = false;
   final StreamController<Pose> _streamController = StreamController.broadcast();
-  final Directory appDir = Directory('/storage/emulated/0/Android/data/com.example.physigo/files');
   Image? _cameraImage;
   Pose? _detectedPose;
   Size _imageSize = Size.zero;
   late Future<void> _startCamera;
-  late Stream<LandmarkVariations> _variationsStream;
   late Stream<MeanFilteredData> _meanFilterStream;
-  StreamController<Color> _stepExerciseStream = StreamController.broadcast();
+  StreamController<StepExercise> _stepExerciseStream = StreamController.broadcast();
+  late Stream<int> _repCounter;
 
   @override
   initState() {
     super.initState();
     _startCamera = _startCameraStream();
     _meanFilterStream = _getMeanFilterStream(_streamController.stream);
-    if (_shouldWriteToFile) {
-      _writeDataToFile(_meanFilterStream);
-    }
-    _variationsStream = _meanFilterStream.pairwise().map(_calculateVariations);
+    CombineLatestStream.combine2(_stepExerciseStream.stream, _meanFilterStream, (a, b) => [a, b]).listen((value) {
+      StepExercise stepExercise = value.first as StepExercise;
+      MeanFilteredData meanFilteredData = value.last as MeanFilteredData;
+      final isStartOfExerciseMovement = _isAtStartOfExerciseMovement(meanFilteredData);
+      final isEndOfExerciseMovement = _isAtEndOfExerciseMovement(meanFilteredData);
+      if (stepExercise == StepExercise.notInPlace && isStartOfExerciseMovement) {
+        _stepExerciseStream.add(StepExercise.ready);
+      }
+      if ((stepExercise == StepExercise.ready || stepExercise == StepExercise.start) && isEndOfExerciseMovement) {
+        _stepExerciseStream.add(StepExercise.end);
+      }
+      if (stepExercise == StepExercise.end && isStartOfExerciseMovement) {
+        _stepExerciseStream.add(StepExercise.start);
+      }
+    });
+    _stepExerciseStream.add(StepExercise.notInPlace);
+    _repCounter = _stepExerciseStream.stream
+        .where((event) => event == StepExercise.start)
+        .scan((int accumulated, value, index) => accumulated + 1, 0);
   }
 
-  LandmarkVariations _calculateVariations(Iterable<MeanFilteredData> pairPositions) {
-    final previous = pairPositions.first.toList();
-    final current = pairPositions.last.toList();
-    LandmarkVariations variations = [];
-    for (int landmark = 0; landmark < previous.length; landmark++) {
-      final dx = current[landmark][0] - previous[landmark][0];
-      final dy = current[landmark][1] - previous[landmark][1];
-      final dz = current[landmark][2] - previous[landmark][2];
-      variations.add([dx.roundToDouble(), dy.roundToDouble(), dz.roundToDouble()]);
+  bool _isAtStartOfExerciseMovement(MeanFilteredData meanFilteredData) {
+    final landmarks = meanFilteredData.toList();
+    final rightShoulder = Point3D(x: landmarks[2][0], y: landmarks[2][1], z: landmarks[2][2]);
+    final rightHip = Point3D(x: landmarks[8][0], y: landmarks[8][1], z: landmarks[8][2]);
+    final rightKnee = Point3D(x: landmarks[10][0], y: landmarks[10][1], z: landmarks[10][2]);
+    final angleRight = DistanceUtils.angleBetweenThreePoints(rightShoulder, rightHip, rightKnee).round();
+    final leftShoulder = Point3D(x: landmarks[1][0], y: landmarks[1][1], z: landmarks[1][2]);
+    final leftHip = Point3D(x: landmarks[7][0], y: landmarks[7][1], z: landmarks[7][2]);
+    final leftKnee = Point3D(x: landmarks[9][0], y: landmarks[9][1], z: landmarks[9][2]);
+    final angleLeft = DistanceUtils.angleBetweenThreePoints(leftShoulder, leftHip, leftKnee).round();
+    if (angleLeft > 320 && angleRight > 320) {
+      return true;
     }
-    return variations;
+    return false;
   }
 
-  void _writeDataToFile(Stream<MeanFilteredData> stream) {
-    File meanFilteredData = File("${appDir.path}/meanFilteredData.csv");
-    if (meanFilteredData.existsSync()) meanFilteredData.deleteSync();
-    stream.listen((meanPositions) {
-      for (var position in meanPositions) {
-        final str = "${position[0]}, ${position[1]}, ${position[2]};";
-        meanFilteredData.writeAsStringSync(str, mode: FileMode.append);
-      }
-      meanFilteredData.writeAsStringSync("\n", mode: FileMode.append);
-    });
+  bool _isAtEndOfExerciseMovement(MeanFilteredData meanFilteredData) {
+    final landmarks = meanFilteredData.toList();
+    final yRightHip = landmarks[8][1];
+    final yRightKnee = landmarks[10][1];
+    final yDistanceRightHipKnee = (yRightHip - yRightKnee).abs();
+    final yLeftHip = landmarks[8][1];
+    final yLeftKnee = landmarks[10][1];
+    final yDistanceLeftHipKnee = (yLeftHip - yLeftKnee).abs();
+    if (yDistanceRightHipKnee < 40 && yDistanceLeftHipKnee < 40) {
+      return true;
+    }
+    return false;
   }
 
   Stream<MeanFilteredData> _getMeanFilterStream(Stream<Pose> stream) {
@@ -178,51 +201,48 @@ class _PoseDetectorState extends State<PoseDetector> {
                 ),
               ),
             ),
-            StreamBuilder<MeanFilteredData>(
-              stream: _meanFilterStream,
-              builder: (context, snapshot) {
-                if (!snapshot.hasData) {
-                  return CircularProgressIndicator();
-                }
-                final landmarks = snapshot.data!.toList();
-                final xRightHip = landmarks[8][0];
-                final xRightKnee = landmarks[10][0];
-                final xDistanceHipKnee = (xRightHip - xRightKnee).abs();
-
-                final yRightHip = landmarks[8][1];
-                final yRightKnee = landmarks[10][1];
-                final yDistanceHipKnee = (yRightHip - yRightKnee).abs();
-                var message = "IN BETWEEN";
-                if (xDistanceHipKnee < 30) {
-                  message = "START";
-                  _stepExerciseStream.add(Colors.green);
-                } else if (yDistanceHipKnee < 40) {
-                  message = "END";
-                  _stepExerciseStream.add(Colors.red);
-                } else {
-                  _stepExerciseStream.add(Colors.yellow);
-                }
-
-                final zRightHip = landmarks[8][2];
-                final zRightKnee = landmarks[10][2];
-                final zDistanceHipKnee = (zRightHip - zRightKnee).abs();
-
-                return Text("$zDistanceHipKnee", style: TextStyle(fontSize: 40));
-              },
-            ),
-            StreamBuilder<Color>(
+            StreamBuilder<StepExercise>(
               stream: _stepExerciseStream.stream,
               builder: (context, snapshot) {
+                Color color;
                 if (!snapshot.hasData) {
-                  return CircularProgressIndicator();
+                  color = Colors.black;
+                } else {
+                  switch (snapshot.data!) {
+                    case StepExercise.notInPlace:
+                      color = Colors.black;
+                      break;
+                    case StepExercise.ready:
+                      color = Colors.green;
+                      break;
+                    case StepExercise.start:
+                      color = Colors.blue;
+                      break;
+                    case StepExercise.end:
+                      color = Colors.red;
+                      break;
+                  }
                 }
                 return Container(
                   height: 100,
                   width: 100,
-                  color: snapshot.data!,
+                  color: color,
                 );
               },
             ),
+            StreamBuilder<int>(
+              stream: _repCounter,
+              builder: (context, snapshot) {
+                var repCounter = 0;
+                if (snapshot.hasData) {
+                  repCounter = snapshot.data!;
+                }
+                return Text(
+                  "$repCounter",
+                  style: TextStyle(fontSize: 40),
+                );
+              },
+            )
           ],
         );
       },
@@ -251,7 +271,7 @@ class _PoseDetectorState extends State<PoseDetector> {
 
 GETTING IN POSITION:
   CHECK IF EVERY NECESSARY JOINT ARE ON SCREEN (reliability > 0.8)
-  CHECK IF START POSITION IS OKAY (for squat, if knee and hip on same x coordinate)
+  CHECK IF START POSITION IS OKAY (for squat, if knee, hip, shoulder aligned)
 
 COUTING REPETITION:
   FROM BEGINNING TO END:

+ 18 - 17
app/lib/main.dart

@@ -34,10 +34,11 @@ class HomePage extends StatelessWidget {
   @override
   Widget build(BuildContext context) {
     return Scaffold(
-      body: Center(
-        child: Column(
-          children: [
-            TextButton(
+      body: Column(
+        mainAxisAlignment: MainAxisAlignment.center,
+        children: [
+          Center(
+            child: TextButton(
               onPressed: () {
                 Navigator.push(
                   context,
@@ -51,19 +52,19 @@ class HomePage extends StatelessWidget {
               },
               child: const Text('Navigation'),
             ),
-            TextButton(
-              onPressed: () {
-                Navigator.push(
-                  context,
-                  MaterialPageRoute(
-                    builder: (context) => ExerciseValidationPage(),
-                  ),
-                );
-              },
-              child: const Text('Exercise Validation'),
-            ),
-          ],
-        ),
+          ),
+          TextButton(
+            onPressed: () {
+              Navigator.push(
+                context,
+                MaterialPageRoute(
+                  builder: (context) => ExerciseValidationPage(),
+                ),
+              );
+            },
+            child: const Text('Exercise Validation'),
+          ),
+        ],
       ),
     );
   }

+ 13 - 0
app/lib/navigation/utils/geometry_utils.dart

@@ -102,4 +102,17 @@ class DistanceUtils {
   static num _dotProduct(Point3D v, Point3D w) {
     return (v.x * w.x) + (v.y * w.y) + (v.z * w.z);
   }
+
+  static num angleBetweenThreePoints(Point3D start, Point3D center, Point3D end) {
+    final a = center - start;
+    final b = center - end;
+    final aLength = _lengthVector(a);
+    final bLength = _lengthVector(b);
+    final dotProduct = _dotProduct(a, b);
+    return acos(dotProduct / (aLength * bLength)) * 360 / pi;
+  }
+
+  static num _lengthVector(Point3D vector) {
+    return sqrt(_sqr(vector.x) + _sqr(vector.y) + _sqr(vector.z));
+  }
 }