pose_detector.dart 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import 'dart:async';
  2. import 'dart:io';
  3. import 'package:body_detection/models/point3d.dart';
  4. import 'package:body_detection/models/pose_landmark.dart';
  5. import 'package:body_detection/models/pose_landmark_type.dart';
  6. import 'package:rxdart/rxdart.dart';
  7. import 'package:body_detection/body_detection.dart';
  8. import 'package:body_detection/models/image_result.dart';
  9. import 'package:body_detection/models/pose.dart';
  10. import 'package:flutter/material.dart';
  11. import 'package:path_provider/path_provider.dart';
  12. import 'pose_painter.dart';
  13. typedef MeanFilteredData = Iterable<List<double>>;
  14. typedef LandmarkVariations = List<List<double>>;
  15. class PoseDetector extends StatefulWidget {
  16. const PoseDetector({Key? key}) : super(key: key);
  17. @override
  18. State<PoseDetector> createState() => _PoseDetectorState();
  19. }
  20. class _PoseDetectorState extends State<PoseDetector> {
  21. static const buffer = 10;
  22. static const _shouldWriteToFile = false;
  23. final StreamController<Pose> _streamController = StreamController.broadcast();
  24. final Directory appDir = Directory('/storage/emulated/0/Android/data/com.example.physigo/files');
  25. Image? _cameraImage;
  26. Pose? _detectedPose;
  27. Size _imageSize = Size.zero;
  28. late Future<void> _startCamera;
  29. late Stream<LandmarkVariations> _variationsStream;
  30. late Stream<MeanFilteredData> _meanFilterStream;
  31. StreamController<Color> _stepExerciseStream = StreamController.broadcast();
  32. @override
  33. initState() {
  34. super.initState();
  35. _startCamera = _startCameraStream();
  36. _meanFilterStream = _getMeanFilterStream(_streamController.stream);
  37. if (_shouldWriteToFile) {
  38. _writeDataToFile(_meanFilterStream);
  39. }
  40. _variationsStream = _meanFilterStream.pairwise().map(_calculateVariations);
  41. }
  42. LandmarkVariations _calculateVariations(Iterable<MeanFilteredData> pairPositions) {
  43. final previous = pairPositions.first.toList();
  44. final current = pairPositions.last.toList();
  45. LandmarkVariations variations = [];
  46. for (int landmark = 0; landmark < previous.length; landmark++) {
  47. final dx = current[landmark][0] - previous[landmark][0];
  48. final dy = current[landmark][1] - previous[landmark][1];
  49. final dz = current[landmark][2] - previous[landmark][2];
  50. variations.add([dx.roundToDouble(), dy.roundToDouble(), dz.roundToDouble()]);
  51. }
  52. return variations;
  53. }
  54. void _writeDataToFile(Stream<MeanFilteredData> stream) {
  55. File meanFilteredData = File("${appDir.path}/meanFilteredData.csv");
  56. if (meanFilteredData.existsSync()) meanFilteredData.deleteSync();
  57. stream.listen((meanPositions) {
  58. for (var position in meanPositions) {
  59. final str = "${position[0]}, ${position[1]}, ${position[2]};";
  60. meanFilteredData.writeAsStringSync(str, mode: FileMode.append);
  61. }
  62. meanFilteredData.writeAsStringSync("\n", mode: FileMode.append);
  63. });
  64. }
  65. Stream<MeanFilteredData> _getMeanFilterStream(Stream<Pose> stream) {
  66. return stream
  67. .where((pose) => pose.landmarks.isNotEmpty)
  68. .map((pose) => pose.landmarks.where((landmark) => authorizedType.contains(landmark.type)).toList())
  69. // Get last [buffer] poses
  70. .bufferCount(buffer, 1)
  71. // Swap matrix [buffer] * [authorizedType.length]
  72. .map(_swapMatrixDimensions)
  73. // For every landmarks, get meanFilter of size [buffer]
  74. .map((filteredLandmarks) => filteredLandmarks.map(_meanFilter));
  75. }
  76. List<double> _meanFilter(List<PoseLandmark> landmarks) {
  77. return landmarks
  78. .map((landmark) => landmark.position)
  79. .map((position) => [
  80. position.x / buffer,
  81. position.y / buffer,
  82. position.z / buffer,
  83. ])
  84. .reduce((value, element) => [
  85. value[0] + element[0],
  86. value[1] + element[1],
  87. value[2] + element[2],
  88. ]);
  89. }
  90. List<List<T>> _swapMatrixDimensions<T>(List<List<T>> matrix) {
  91. final height = matrix.length;
  92. final width = matrix[0].length;
  93. List<List<T>> newMatrix = [];
  94. for (int col = 0; col < width; col++) {
  95. List<T> newRow = [];
  96. for (int row = 0; row < height; row++) {
  97. newRow.add(matrix[row][col]);
  98. }
  99. newMatrix.add(newRow);
  100. }
  101. return newMatrix;
  102. }
  103. Future<void> _startCameraStream() async {
  104. await BodyDetection.startCameraStream(onFrameAvailable: _handleCameraImage, onPoseAvailable: _handlePose);
  105. await BodyDetection.enablePoseDetection();
  106. }
  107. Future<void> _stopCameraStream() async {
  108. await BodyDetection.disablePoseDetection();
  109. await BodyDetection.stopCameraStream();
  110. }
  111. void _handleCameraImage(ImageResult result) {
  112. if (!mounted) return;
  113. // To avoid a memory leak issue.
  114. // https://github.com/flutter/flutter/issues/60160
  115. PaintingBinding.instance?.imageCache?.clear();
  116. PaintingBinding.instance?.imageCache?.clearLiveImages();
  117. final image = Image.memory(
  118. result.bytes,
  119. gaplessPlayback: true,
  120. fit: BoxFit.contain,
  121. );
  122. setState(() {
  123. _cameraImage = image;
  124. _imageSize = result.size;
  125. });
  126. }
  127. void _handlePose(Pose? pose) {
  128. if (!mounted) return;
  129. if (pose != null) _streamController.add(pose);
  130. setState(() {
  131. _detectedPose = pose;
  132. });
  133. }
  134. @override
  135. void dispose() {
  136. _stopCameraStream();
  137. _streamController.close();
  138. super.dispose();
  139. }
  140. @override
  141. Widget build(BuildContext context) {
  142. return FutureBuilder<void>(
  143. future: _startCamera,
  144. builder: (context, snapshot) {
  145. if (snapshot.connectionState == ConnectionState.waiting) {
  146. return const Center(child: CircularProgressIndicator());
  147. }
  148. return Column(
  149. children: [
  150. Center(
  151. child: CustomPaint(
  152. // size: _imageSize,
  153. child: _cameraImage,
  154. foregroundPainter: PosePainter(
  155. pose: _detectedPose,
  156. imageSize: _imageSize,
  157. ),
  158. ),
  159. ),
  160. StreamBuilder<MeanFilteredData>(
  161. stream: _meanFilterStream,
  162. builder: (context, snapshot) {
  163. if (!snapshot.hasData) {
  164. return CircularProgressIndicator();
  165. }
  166. final landmarks = snapshot.data!.toList();
  167. final xRightHip = landmarks[8][0];
  168. final xRightKnee = landmarks[10][0];
  169. final xDistanceHipKnee = (xRightHip - xRightKnee).abs();
  170. final yRightHip = landmarks[8][1];
  171. final yRightKnee = landmarks[10][1];
  172. final yDistanceHipKnee = (yRightHip - yRightKnee).abs();
  173. var message = "IN BETWEEN";
  174. if (xDistanceHipKnee < 30) {
  175. message = "START";
  176. _stepExerciseStream.add(Colors.green);
  177. } else if (yDistanceHipKnee < 40) {
  178. message = "END";
  179. _stepExerciseStream.add(Colors.red);
  180. } else {
  181. _stepExerciseStream.add(Colors.yellow);
  182. }
  183. final zRightHip = landmarks[8][2];
  184. final zRightKnee = landmarks[10][2];
  185. final zDistanceHipKnee = (zRightHip - zRightKnee).abs();
  186. return Text("$zDistanceHipKnee", style: TextStyle(fontSize: 40));
  187. },
  188. ),
  189. StreamBuilder<Color>(
  190. stream: _stepExerciseStream.stream,
  191. builder: (context, snapshot) {
  192. if (!snapshot.hasData) {
  193. return CircularProgressIndicator();
  194. }
  195. return Container(
  196. height: 100,
  197. width: 100,
  198. color: snapshot.data!,
  199. );
  200. },
  201. ),
  202. ],
  203. );
  204. },
  205. );
  206. }
  207. static const authorizedType = [
  208. PoseLandmarkType.nose,
  209. PoseLandmarkType.leftShoulder,
  210. PoseLandmarkType.rightShoulder,
  211. PoseLandmarkType.leftElbow,
  212. PoseLandmarkType.rightElbow,
  213. PoseLandmarkType.leftWrist,
  214. PoseLandmarkType.rightWrist,
  215. PoseLandmarkType.leftHip,
  216. PoseLandmarkType.rightHip,
  217. PoseLandmarkType.leftKnee,
  218. PoseLandmarkType.rightKnee,
  219. PoseLandmarkType.leftAnkle,
  220. PoseLandmarkType.rightAnkle,
  221. ];
  222. }
/*
GETTING INTO POSITION:
CHECK THAT EVERY NECESSARY JOINT IS ON SCREEN (reliability > 0.8)
CHECK THAT THE START POSITION IS CORRECT (for a squat: knee and hip on the same x coordinate)
COUNTING REPETITIONS:
FROM BEGINNING TO END:
- BEGINNING: defined by the start position; record the positions of the relevant joints
- END: defined by the positions/distances of the relevant joints (knee and hip at the same level for a squat,
elbow and shoulder at the same level for a push-up)
*/