#ifndef MEDIAPIPE_CALCULATORS_INFERENCE_INFERENCE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_INFERENCE_INFERENCE_CALCULATOR_H_
#include <memory>
#include <string>
#include <vector>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/lite/interpreter.h"
#include "inference_calculator_options.pb.h"
namespace mediapipe {
class ConfigurableInferenceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("IMAGE").Set<ImageFrame>(); cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>(); cc->Options<InferenceCalculatorOptions>(); cc->InputSidePackets().Tag("OVERRIDE_THRESHOLD").Set<float>().Optional(); cc->InputSidePackets().Tag("MODEL_PATH").Set<std::string>(); return absl::OkStatus(); }
absl::Status Open(CalculatorContext* cc) override { const auto& options = cc->Options<InferenceCalculatorOptions>(); RET_CHECK(options.has_input_size()) << "input_size is required"; model_path_ = cc->InputSidePackets().Tag("MODEL_PATH").Get<std::string>(); num_threads_ = options.num_threads(); backend_ = options.backend(); input_width_ = options.input_size().width(); input_height_ = options.input_size().height(); score_threshold_ = options.score_threshold(); iou_threshold_ = options.iou_threshold(); max_detections_ = options.max_detections(); if (cc->InputSidePackets().HasTag("OVERRIDE_THRESHOLD")) { score_threshold_ = cc->InputSidePackets().Tag("OVERRIDE_THRESHOLD").Get<float>(); LOG(INFO) << "Score threshold overridden to: " << score_threshold_; } RET_CHECK(!model_path_.empty()) << "model_path is required"; RET_CHECK(input_width_ > 0 && input_height_ > 0) << "Invalid input size"; RET_CHECK(score_threshold_ > 0.0f && score_threshold_ < 1.0f) << "Invalid score threshold"; RET_CHECK(iou_threshold_ > 0.0f && iou_threshold_ < 1.0f) << "Invalid IoU threshold"; MP_RETURN_IF_ERROR(LoadModel(model_path_, backend_, num_threads_)); LOG(INFO) << "ConfigurableInferenceCalculator initialized: " << "model=" << model_path_ << ", input_size=" << input_width_ << "x" << input_height_ << ", threshold=" << score_threshold_ << ", backend=" << Backend_Name(backend_); return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag("IMAGE").IsEmpty()) { return absl::OkStatus(); } const ImageFrame& image = cc->Inputs().Tag("IMAGE").Get<ImageFrame>(); cv::Mat input = Preprocess(image, input_width_, input_height_); auto raw_output = Inference(input); std::vector<Detection> detections = Postprocess( raw_output, score_threshold_, iou_threshold_, max_detections_); cc->Outputs().Tag("DETECTIONS").AddPacket( MakePacket<std::vector<Detection>>(detections).At(cc->InputTimestamp())); process_count_++; return absl::OkStatus(); }
absl::Status Close(CalculatorContext* cc) override { if (interpreter_) { interpreter_->Reset(); } LOG(INFO) << "ConfigurableInferenceCalculator closed, processed " << process_count_ << " frames"; return absl::OkStatus(); }
private: std::string model_path_; int num_threads_ = 4; int input_width_ = 320; int input_height_ = 320; float score_threshold_ = 0.5f; float iou_threshold_ = 0.45f; int max_detections_ = 100; InferenceCalculatorOptions::Backend backend_ = InferenceCalculatorOptions::CPU; std::unique_ptr<tflite::Interpreter> interpreter_; int process_count_ = 0; absl::Status LoadModel(const std::string& path, InferenceCalculatorOptions::Backend backend, int num_threads); cv::Mat Preprocess(const ImageFrame& image, int width, int height); std::vector<float> Inference(const cv::Mat& input); std::vector<Detection> Postprocess(const std::vector<float>& output, float score_threshold, float iou_threshold, int max_detections); };
// Registers the calculator with MediaPipe's global registry so graphs can
// reference it by type name.
// NOTE(review): this invocation lives in a header; REGISTER_CALCULATOR
// typically expands to a static registration object, so including this
// header from multiple translation units may register (or define) it more
// than once. Conventionally registration goes in the .cc file — confirm
// against the build before moving it.
REGISTER_CALCULATOR(ConfigurableInferenceCalculator);
}
#endif