ESPHome  2024.11.0
streaming_model.h
Go to the documentation of this file.
1 #pragma once
2 
3 #ifdef USE_ESP_IDF
4 
6 
7 #include <tensorflow/lite/core/c/common.h>
8 #include <tensorflow/lite/micro/micro_interpreter.h>
9 #include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
10 
11 namespace esphome {
12 namespace micro_wake_word {
13 
/// NOTE(review): presumably the size of the variable arena (var_arena_) that load_model()
/// allocates; units (bytes?) are not established by this header — confirm in the .cpp.
static constexpr uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024;
15 
17  public:
/// @brief Pure virtual hook; WakeWordModel and VADModel both override it.
/// NOTE(review): presumably logs the model's configuration (inferred from the name) — confirm
/// against the subclass implementations.
18  virtual void log_model_config() = 0;
/// @brief Pure virtual hook; WakeWordModel and VADModel both override it.
/// @return A bool decided by the subclass. NOTE(review): presumably true when the model's target
///         (wake word / voice activity) is considered detected — confirm in the implementations.
19  virtual bool determine_detected() = 0;
20 
/// @brief Performs one step of streaming inference on a set of input features (per the name).
/// @param features int8 feature values. The bound PREPROCESSOR_FEATURE_SIZE (defined outside
///        this listing) is NOT enforced by the compiler: an array parameter decays to a pointer,
///        so the caller must supply at least that many elements.
/// @return A bool status. NOTE(review): presumably false when inference fails — confirm in the .cpp.
21  bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]);
22 
/// @brief Sets all recent_streaming_probabilities to 0.
24  void reset_probabilities();
25 
/// @brief Allocates tensor and variable arenas and sets up the model interpreter.
/// @param op_resolver TFLM op resolver used to set up the interpreter. The template argument (20)
///        is the resolver's operator capacity. NOTE(review): it is taken by non-const reference, so
///        it likely has to outlive the loaded model (until unload_model()) — confirm.
/// @return A bool status. NOTE(review): presumably true on success, false otherwise — confirm.
29  bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver);
30 
/// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory.
32  void unload_model();
33 
34  protected:
36 
39  size_t last_n_index_{0};
41  std::vector<uint8_t> recent_streaming_probabilities_;
42 
43  const uint8_t *model_start_;
44  uint8_t *tensor_arena_{nullptr};
45  uint8_t *var_arena_{nullptr};
46  std::unique_ptr<tflite::MicroInterpreter> interpreter_;
47  tflite::MicroResourceVariables *mrv_{nullptr};
48  tflite::MicroAllocator *ma_{nullptr};
49 };
50 
51 class WakeWordModel final : public StreamingModel {
52  public:
53  WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
54  const std::string &wake_word, size_t tensor_arena_size);
55 
56  void log_model_config() override;
57 
61  bool determine_detected() override;
62 
63  const std::string &get_wake_word() const { return this->wake_word_; }
64 
65  protected:
66  std::string wake_word_;
67 };
68 
69 class VADModel final : public StreamingModel {
70  public:
71  VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size);
72 
73  void log_model_config() override;
74 
78  bool determine_detected() override;
79 };
80 
81 } // namespace micro_wake_word
82 } // namespace esphome
83 
84 #endif
const std::string & get_wake_word() const
tflite::MicroResourceVariables * mrv_
void unload_model()
Destroys the TFLite interpreter and frees the tensor and variable arenas' memory. ...
std::vector< uint8_t > recent_streaming_probabilities_
bool load_model(tflite::MicroMutableOpResolver< 20 > &op_resolver)
Allocates tensor and variable arenas and sets up the model interpreter.
void reset_probabilities()
Sets all recent_streaming_probabilities to 0.
std::unique_ptr< tflite::MicroInterpreter > interpreter_
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE])
Implementation of SPI Controller mode.
Definition: a01nyub.cpp:7