pose_detector.dart 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import 'dart:async';
  2. import 'package:body_detection/models/pose_landmark.dart';
  3. import 'package:body_detection/models/pose_landmark_type.dart';
  4. import 'package:physigo/navigation/utils/geometry_utils.dart';
  5. import 'package:rxdart/rxdart.dart';
  6. import 'package:body_detection/body_detection.dart';
  7. import 'package:body_detection/models/image_result.dart';
  8. import 'package:body_detection/models/pose.dart';
  9. import 'package:flutter/material.dart';
  10. import 'pose_painter.dart';
  /// Mean-filtered landmark coordinates: one `[x, y, z]` entry per landmark.
  typedef MeanFilteredData = Iterable<List<double>>;

  /// Per-landmark coordinate variations.
  ///
  /// NOTE(review): not referenced in this file — confirm whether other code
  /// still depends on it before removing.
  typedef LandmarkVariations = List<List<double>>;

  /// Phases of the repetition-counting state machine.
  enum StepExercise {
    /// The start position has not been reached yet.
    notInPlace,

    /// The start position was reached for the first time.
    ready,

    /// Back at the start position after an end position; each transition into
    /// this phase counts one repetition.
    start,

    /// The end position of the movement was reached.
    end,
  }
  /// Camera screen that overlays the detected pose on the live preview and
  /// shows the current exercise step and the repetition count.
  class PoseDetector extends StatefulWidget {
    const PoseDetector({Key? key}) : super(key: key);

    @override
    State<PoseDetector> createState() {
      return _PoseDetectorState();
    }
  }
  /// State for [PoseDetector].
  ///
  /// Runs the camera/pose-detection stream, smooths the detected landmarks with
  /// a sliding mean filter, and drives a small state machine
  /// ([StepExercise]) that counts repetitions.
  class _PoseDetectorState extends State<PoseDetector> {
    /// Number of consecutive poses averaged by the mean filter.
    static const buffer = 10;

    /// Landmarks kept for analysis.
    ///
    /// Mean-filtered data is ordered exactly like this list, so landmark
    /// indices are always derived from it (see [_pointOf]).
    static const authorizedType = [
      PoseLandmarkType.nose,
      PoseLandmarkType.leftShoulder,
      PoseLandmarkType.rightShoulder,
      PoseLandmarkType.leftElbow,
      PoseLandmarkType.rightElbow,
      PoseLandmarkType.leftWrist,
      PoseLandmarkType.rightWrist,
      PoseLandmarkType.leftHip,
      PoseLandmarkType.rightHip,
      PoseLandmarkType.leftKnee,
      PoseLandmarkType.rightKnee,
      PoseLandmarkType.leftAnkle,
      PoseLandmarkType.rightAnkle,
    ];

    /// Both rounded shoulder-hip-knee angles must exceed this for the start
    /// position.
    ///
    /// NOTE(review): a threshold above 180 suggests
    /// `DistanceUtils.angleBetweenThreePoints` returns degrees in a 0–360
    /// convention — confirm.
    static const _startAngleThreshold = 320;

    /// Both vertical hip-knee distances must be below this for the end
    /// position. NOTE(review): presumably image pixels — confirm.
    static const _endMaxHipKneeYDistance = 40;

    final StreamController<Pose> _streamController = StreamController.broadcast();

    /// Source of truth for the current [StepExercise]; fed back into the
    /// state machine through [_stepSubscription].
    final StreamController<StepExercise> _stepExerciseController =
        StreamController.broadcast();

    /// Subscription running the state machine; cancelled in [dispose].
    late final StreamSubscription<StepExercise?> _stepSubscription;

    Image? _cameraImage;
    Pose? _detectedPose;
    Size _imageSize = Size.zero;
    late Future<void> _startCamera;
    late Stream<MeanFilteredData> _meanFilterStream;
    late Stream<int> _repCounter;

    @override
    void initState() {
      super.initState();
      _startCamera = _startCameraStream();
      _meanFilterStream = _getMeanFilterStream(_streamController.stream);
      // Each (current step, latest smoothed pose) pair may trigger a
      // transition. The new step is pushed back into the controller, which
      // re-emits the pair with the updated step.
      _stepSubscription = CombineLatestStream.combine2<StepExercise,
          MeanFilteredData, StepExercise?>(
        _stepExerciseController.stream,
        _meanFilterStream,
        _nextStep,
      ).listen((nextStep) {
        if (nextStep != null) _stepExerciseController.add(nextStep);
      });
      _stepExerciseController.add(StepExercise.notInPlace);
      // One repetition per transition into [StepExercise.start].
      _repCounter = _stepExerciseController.stream
          .where((step) => step == StepExercise.start)
          .scan((int accumulated, value, index) => accumulated + 1, 0);
    }

    /// Returns the step to move to from [current] given the smoothed
    /// landmarks, or `null` when no transition applies.
    StepExercise? _nextStep(StepExercise current, MeanFilteredData data) {
      switch (current) {
        case StepExercise.notInPlace:
          return _isAtStartOfExerciseMovement(data) ? StepExercise.ready : null;
        case StepExercise.ready:
        case StepExercise.start:
          return _isAtEndOfExerciseMovement(data) ? StepExercise.end : null;
        case StepExercise.end:
          return _isAtStartOfExerciseMovement(data) ? StepExercise.start : null;
      }
    }

    /// Builds a [Point3D] from the smoothed coordinates of landmark [type].
    Point3D _pointOf(List<List<double>> landmarks, PoseLandmarkType type) {
      final coordinates = landmarks[authorizedType.indexOf(type)];
      return Point3D(x: coordinates[0], y: coordinates[1], z: coordinates[2]);
    }

    /// Rounded angle returned by `DistanceUtils.angleBetweenThreePoints` for
    /// the three given landmarks, passed in that order.
    int _roundedAngle(
      List<List<double>> landmarks,
      PoseLandmarkType first,
      PoseLandmarkType middle,
      PoseLandmarkType last,
    ) {
      return DistanceUtils.angleBetweenThreePoints(
        _pointOf(landmarks, first),
        _pointOf(landmarks, middle),
        _pointOf(landmarks, last),
      ).round();
    }

    /// Whether both shoulder-hip-knee angles exceed [_startAngleThreshold].
    bool _isAtStartOfExerciseMovement(MeanFilteredData meanFilteredData) {
      final landmarks = meanFilteredData.toList();
      final angleRight = _roundedAngle(
        landmarks,
        PoseLandmarkType.rightShoulder,
        PoseLandmarkType.rightHip,
        PoseLandmarkType.rightKnee,
      );
      final angleLeft = _roundedAngle(
        landmarks,
        PoseLandmarkType.leftShoulder,
        PoseLandmarkType.leftHip,
        PoseLandmarkType.leftKnee,
      );
      return angleLeft > _startAngleThreshold && angleRight > _startAngleThreshold;
    }

    /// Whether, on both sides, the hip and knee are within
    /// [_endMaxHipKneeYDistance] of each other vertically.
    bool _isAtEndOfExerciseMovement(MeanFilteredData meanFilteredData) {
      final landmarks = meanFilteredData.toList();
      double yOf(PoseLandmarkType type) => landmarks[authorizedType.indexOf(type)][1];

      final yDistanceRightHipKnee =
          (yOf(PoseLandmarkType.rightHip) - yOf(PoseLandmarkType.rightKnee)).abs();
      // Previously this read the right hip/knee a second time, so the left
      // side was never checked.
      final yDistanceLeftHipKnee =
          (yOf(PoseLandmarkType.leftHip) - yOf(PoseLandmarkType.leftKnee)).abs();
      return yDistanceRightHipKnee < _endMaxHipKneeYDistance &&
          yDistanceLeftHipKnee < _endMaxHipKneeYDistance;
    }

    /// Turns raw poses into a sliding-window mean of every [authorizedType]
    /// landmark, emitting one `[x, y, z]` entry per landmark.
    Stream<MeanFilteredData> _getMeanFilterStream(Stream<Pose> stream) {
      return stream
          // Drop poses lacking any authorized landmark so every buffered row
          // has the same length and order.
          .map(_selectAuthorizedLandmarks)
          .whereType<List<PoseLandmark>>()
          // Get last [buffer] poses.
          .bufferCount(buffer, 1)
          // Swap matrix [buffer] * [authorizedType.length].
          .map(_swapMatrixDimensions)
          // Mean per landmark; materialized once because consumers call
          // toList() on it several times.
          .map((columns) => columns.map(_meanFilter).toList());
    }

    /// Returns the landmarks of [pose] ordered like [authorizedType], or
    /// `null` when any of them is missing.
    List<PoseLandmark>? _selectAuthorizedLandmarks(Pose pose) {
      final byType = {for (final landmark in pose.landmarks) landmark.type: landmark};
      final selected = <PoseLandmark>[];
      for (final type in authorizedType) {
        final landmark = byType[type];
        if (landmark == null) return null;
        selected.add(landmark);
      }
      return selected;
    }

    /// Averages the positions of [landmarks] per coordinate.
    ///
    /// Divides by the actual number of samples, so trailing partial buffers
    /// are still averaged correctly. [landmarks] must not be empty.
    List<double> _meanFilter(List<PoseLandmark> landmarks) {
      final count = landmarks.length;
      var sumX = 0.0;
      var sumY = 0.0;
      var sumZ = 0.0;
      for (final landmark in landmarks) {
        final position = landmark.position;
        sumX += position.x;
        sumY += position.y;
        sumZ += position.z;
      }
      return [sumX / count, sumY / count, sumZ / count];
    }

    /// Transposes a rectangular [matrix]; an empty matrix yields an empty list.
    List<List<T>> _swapMatrixDimensions<T>(List<List<T>> matrix) {
      if (matrix.isEmpty) return [];
      final width = matrix.first.length;
      return List.generate(width, (col) => [for (final row in matrix) row[col]]);
    }

    /// Starts the camera stream with the frame/pose callbacks, then enables
    /// pose detection.
    Future<void> _startCameraStream() async {
      await BodyDetection.startCameraStream(
        onFrameAvailable: _handleCameraImage,
        onPoseAvailable: _handlePose,
      );
      await BodyDetection.enablePoseDetection();
    }

    /// Disables pose detection, then stops the camera stream.
    Future<void> _stopCameraStream() async {
      await BodyDetection.disablePoseDetection();
      await BodyDetection.stopCameraStream();
    }

    void _handleCameraImage(ImageResult result) {
      if (!mounted) return;
      // To avoid a memory leak issue.
      // https://github.com/flutter/flutter/issues/60160
      PaintingBinding.instance?.imageCache?.clear();
      PaintingBinding.instance?.imageCache?.clearLiveImages();
      final image = Image.memory(
        result.bytes,
        gaplessPlayback: true,
        fit: BoxFit.contain,
      );
      setState(() {
        _cameraImage = image;
        _imageSize = result.size;
      });
    }

    void _handlePose(Pose? pose) {
      if (!mounted) return;
      if (pose != null) _streamController.add(pose);
      setState(() {
        _detectedPose = pose;
      });
    }

    @override
    void dispose() {
      // Fire-and-forget: the widget is going away, nothing can await this.
      _stopCameraStream();
      // Stop the state machine before closing the controller it writes to.
      _stepSubscription.cancel();
      _stepExerciseController.close();
      _streamController.close();
      super.dispose();
    }

    /// Color of the step indicator; black until the user is in place.
    Color _colorForStep(StepExercise? step) {
      switch (step) {
        case StepExercise.ready:
          return Colors.green;
        case StepExercise.start:
          return Colors.blue;
        case StepExercise.end:
          return Colors.red;
        default:
          // StepExercise.notInPlace, or no step received yet.
          return Colors.black;
      }
    }

    @override
    Widget build(BuildContext context) {
      return FutureBuilder<void>(
        future: _startCamera,
        builder: (context, snapshot) {
          if (snapshot.connectionState == ConnectionState.waiting) {
            return const Center(child: CircularProgressIndicator());
          }
          return Column(
            children: [
              Center(
                child: CustomPaint(
                  foregroundPainter: PosePainter(
                    pose: _detectedPose,
                    imageSize: _imageSize,
                  ),
                  child: _cameraImage,
                ),
              ),
              StreamBuilder<StepExercise>(
                stream: _stepExerciseController.stream,
                builder: (context, snapshot) => Container(
                  height: 100,
                  width: 100,
                  color: _colorForStep(snapshot.data),
                ),
              ),
              StreamBuilder<int>(
                stream: _repCounter,
                builder: (context, snapshot) => Text(
                  "${snapshot.data ?? 0}",
                  style: const TextStyle(fontSize: 40),
                ),
              ),
            ],
          );
        },
      );
    }
  }
  /*
  GETTING IN POSITION:
  - CHECK THAT EVERY NECESSARY JOINT IS ON SCREEN (reliability > 0.8)
  - CHECK THAT THE START POSITION IS CORRECT (for a squat: knee, hip and shoulder aligned)
  COUNTING REPETITIONS:
  FROM BEGINNING TO END:
  - BEGINNING: defined by the start position; record the positions of the relevant joints
  - END: defined by the positions/distances of the relevant joints (knee and hip at the
    same level for a squat, elbow and shoulder at the same level for a push-up)
  */