Upgrade Playlist Features

This commit is contained in:
2025-12-09 17:20:01 +08:00
parent 577990de69
commit 8bd2780688
683 changed files with 91812 additions and 81260 deletions

View File

@@ -1,372 +1,372 @@
#include "mqtt_protocol.h"
#include "board.h"
#include "application.h"
#include "settings.h"
#include <esp_log.h>
#include <cstring>
#include <arpa/inet.h>
#include "assets/lang_config.h"
#define TAG "MQTT"
MqttProtocol::MqttProtocol() {
event_group_handle_ = xEventGroupCreate();
// Initialize reconnect timer
esp_timer_create_args_t reconnect_timer_args = {
.callback = [](void* arg) {
MqttProtocol* protocol = (MqttProtocol*)arg;
auto& app = Application::GetInstance();
if (app.GetDeviceState() == kDeviceStateIdle) {
ESP_LOGI(TAG, "Reconnecting to MQTT server");
app.Schedule([protocol]() {
protocol->StartMqttClient(false);
});
}
},
.arg = this,
};
esp_timer_create(&reconnect_timer_args, &reconnect_timer_);
}
MqttProtocol::~MqttProtocol() {
ESP_LOGI(TAG, "MqttProtocol deinit");
if (reconnect_timer_ != nullptr) {
esp_timer_stop(reconnect_timer_);
esp_timer_delete(reconnect_timer_);
}
udp_.reset();
mqtt_.reset();
if (event_group_handle_ != nullptr) {
vEventGroupDelete(event_group_handle_);
}
}
bool MqttProtocol::Start() {
return StartMqttClient(false);
}
bool MqttProtocol::StartMqttClient(bool report_error) {
if (mqtt_ != nullptr) {
ESP_LOGW(TAG, "Mqtt client already started");
mqtt_.reset();
}
Settings settings("mqtt", false);
auto endpoint = settings.GetString("endpoint");
auto client_id = settings.GetString("client_id");
auto username = settings.GetString("username");
auto password = settings.GetString("password");
int keepalive_interval = settings.GetInt("keepalive", 240);
publish_topic_ = settings.GetString("publish_topic");
if (endpoint.empty()) {
ESP_LOGW(TAG, "MQTT endpoint is not specified");
if (report_error) {
SetError(Lang::Strings::SERVER_NOT_FOUND);
}
return false;
}
auto network = Board::GetInstance().GetNetwork();
mqtt_ = network->CreateMqtt(0);
mqtt_->SetKeepAlive(keepalive_interval);
mqtt_->OnDisconnected([this]() {
if (on_disconnected_ != nullptr) {
on_disconnected_();
}
ESP_LOGI(TAG, "MQTT disconnected, schedule reconnect in %d seconds", MQTT_RECONNECT_INTERVAL_MS / 1000);
esp_timer_start_once(reconnect_timer_, MQTT_RECONNECT_INTERVAL_MS * 1000);
});
mqtt_->OnConnected([this]() {
if (on_connected_ != nullptr) {
on_connected_();
}
esp_timer_stop(reconnect_timer_);
});
mqtt_->OnMessage([this](const std::string& topic, const std::string& payload) {
cJSON* root = cJSON_Parse(payload.c_str());
if (root == nullptr) {
ESP_LOGE(TAG, "Failed to parse json message %s", payload.c_str());
return;
}
cJSON* type = cJSON_GetObjectItem(root, "type");
if (!cJSON_IsString(type)) {
ESP_LOGE(TAG, "Message type is invalid");
cJSON_Delete(root);
return;
}
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else if (strcmp(type->valuestring, "goodbye") == 0) {
auto session_id = cJSON_GetObjectItem(root, "session_id");
ESP_LOGI(TAG, "Received goodbye message, session_id: %s", session_id ? session_id->valuestring : "null");
if (session_id == nullptr || session_id_ == session_id->valuestring) {
Application::GetInstance().Schedule([this]() {
CloseAudioChannel();
});
}
} else if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
cJSON_Delete(root);
last_incoming_time_ = std::chrono::steady_clock::now();
});
ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint.c_str());
std::string broker_address;
int broker_port = 8883;
size_t pos = endpoint.find(':');
if (pos != std::string::npos) {
broker_address = endpoint.substr(0, pos);
broker_port = std::stoi(endpoint.substr(pos + 1));
} else {
broker_address = endpoint;
}
if (!mqtt_->Connect(broker_address, broker_port, client_id, username, password)) {
ESP_LOGE(TAG, "Failed to connect to endpoint");
SetError(Lang::Strings::SERVER_NOT_CONNECTED);
return false;
}
ESP_LOGI(TAG, "Connected to endpoint");
return true;
}
bool MqttProtocol::SendText(const std::string& text) {
if (publish_topic_.empty()) {
return false;
}
if (!mqtt_->Publish(publish_topic_, text)) {
ESP_LOGE(TAG, "Failed to publish message: %s", text.c_str());
SetError(Lang::Strings::SERVER_ERROR);
return false;
}
return true;
}
bool MqttProtocol::SendAudio(std::unique_ptr<AudioStreamPacket> packet) {
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return false;
}
std::string nonce(aes_nonce_);
*(uint16_t*)&nonce[2] = htons(packet->payload.size());
*(uint32_t*)&nonce[8] = htonl(packet->timestamp);
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
std::string encrypted;
encrypted.resize(aes_nonce_.size() + packet->payload.size());
memcpy(encrypted.data(), nonce.data(), nonce.size());
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
if (mbedtls_aes_crypt_ctr(&aes_ctx_, packet->payload.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block,
(uint8_t*)packet->payload.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) {
ESP_LOGE(TAG, "Failed to encrypt audio data");
return false;
}
return udp_->Send(encrypted) > 0;
}
void MqttProtocol::CloseAudioChannel() {
{
std::lock_guard<std::mutex> lock(channel_mutex_);
udp_.reset();
}
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"goodbye\"";
message += "}";
SendText(message);
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
}
bool MqttProtocol::OpenAudioChannel() {
if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
if (!StartMqttClient(true)) {
return false;
}
}
error_occurred_ = false;
session_id_ = "";
xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
auto message = GetHelloMessage();
if (!SendText(message)) {
return false;
}
// 等待服务器响应
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) {
ESP_LOGE(TAG, "Failed to receive server hello");
SetError(Lang::Strings::SERVER_TIMEOUT);
return false;
}
std::lock_guard<std::mutex> lock(channel_mutex_);
auto network = Board::GetInstance().GetNetwork();
udp_ = network->CreateUdp(2);
udp_->OnMessage([this](const std::string& data) {
/*
* UDP Encrypted OPUS Packet Format:
* |type 1u|flags 1u|payload_len 2u|ssrc 4u|timestamp 4u|sequence 4u|
* |payload payload_len|
*/
if (data.size() < sizeof(aes_nonce_)) {
ESP_LOGE(TAG, "Invalid audio packet size: %u", data.size());
return;
}
if (data[0] != 0x01) {
ESP_LOGE(TAG, "Invalid audio packet type: %x", data[0]);
return;
}
uint32_t timestamp = ntohl(*(uint32_t*)&data[8]);
uint32_t sequence = ntohl(*(uint32_t*)&data[12]);
if (sequence < remote_sequence_) {
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
return;
}
if (sequence != remote_sequence_ + 1) {
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
}
size_t decrypted_size = data.size() - aes_nonce_.size();
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
auto nonce = (uint8_t*)data.data();
auto encrypted = (uint8_t*)data.data() + aes_nonce_.size();
auto packet = std::make_unique<AudioStreamPacket>();
packet->sample_rate = server_sample_rate_;
packet->frame_duration = server_frame_duration_;
packet->timestamp = timestamp;
packet->payload.resize(decrypted_size);
int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)packet->payload.data());
if (ret != 0) {
ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret);
return;
}
if (on_incoming_audio_ != nullptr) {
on_incoming_audio_(std::move(packet));
}
remote_sequence_ = sequence;
last_incoming_time_ = std::chrono::steady_clock::now();
});
udp_->Connect(udp_server_, udp_port_);
if (on_audio_channel_opened_ != nullptr) {
on_audio_channel_opened_();
}
return true;
}
std::string MqttProtocol::GetHelloMessage() {
// 发送 hello 消息申请 UDP 通道
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "hello");
cJSON_AddNumberToObject(root, "version", 3);
cJSON_AddStringToObject(root, "transport", "udp");
cJSON* features = cJSON_CreateObject();
#if CONFIG_USE_SERVER_AEC
cJSON_AddBoolToObject(features, "aec", true);
#endif
cJSON_AddBoolToObject(features, "mcp", true);
cJSON_AddItemToObject(root, "features", features);
cJSON* audio_params = cJSON_CreateObject();
cJSON_AddStringToObject(audio_params, "format", "opus");
cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
cJSON_AddNumberToObject(audio_params, "channels", 1);
cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
cJSON_AddItemToObject(root, "audio_params", audio_params);
auto json_str = cJSON_PrintUnformatted(root);
std::string message(json_str);
cJSON_free(json_str);
cJSON_Delete(root);
return message;
}
void MqttProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
return;
}
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (cJSON_IsString(session_id)) {
session_id_ = session_id->valuestring;
ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
}
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (cJSON_IsObject(audio_params)) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (cJSON_IsNumber(sample_rate)) {
server_sample_rate_ = sample_rate->valueint;
}
auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
if (cJSON_IsNumber(frame_duration)) {
server_frame_duration_ = frame_duration->valueint;
}
}
auto udp = cJSON_GetObjectItem(root, "udp");
if (!cJSON_IsObject(udp)) {
ESP_LOGE(TAG, "UDP is not specified");
return;
}
udp_server_ = cJSON_GetObjectItem(udp, "server")->valuestring;
udp_port_ = cJSON_GetObjectItem(udp, "port")->valueint;
auto key = cJSON_GetObjectItem(udp, "key")->valuestring;
auto nonce = cJSON_GetObjectItem(udp, "nonce")->valuestring;
// auto encryption = cJSON_GetObjectItem(udp, "encryption")->valuestring;
// ESP_LOGI(TAG, "UDP server: %s, port: %d, encryption: %s", udp_server_.c_str(), udp_port_, encryption);
aes_nonce_ = DecodeHexString(nonce);
mbedtls_aes_init(&aes_ctx_);
mbedtls_aes_setkey_enc(&aes_ctx_, (const unsigned char*)DecodeHexString(key).c_str(), 128);
local_sequence_ = 0;
remote_sequence_ = 0;
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
}
static const char hex_chars[] = "0123456789ABCDEF";
// 辅助函数,将单个十六进制字符转换为对应的数值
static inline uint8_t CharToHex(char c) {
if (c >= '0' && c <= '9') return c - '0';
if (c >= 'A' && c <= 'F') return c - 'A' + 10;
if (c >= 'a' && c <= 'f') return c - 'a' + 10;
return 0; // 对于无效输入返回0
}
std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
std::string decoded;
decoded.reserve(hex_string.size() / 2);
for (size_t i = 0; i < hex_string.size(); i += 2) {
char byte = (CharToHex(hex_string[i]) << 4) | CharToHex(hex_string[i + 1]);
decoded.push_back(byte);
}
return decoded;
}
bool MqttProtocol::IsAudioChannelOpened() const {
return udp_ != nullptr && !error_occurred_ && !IsTimeout();
}
#include "mqtt_protocol.h"
#include "board.h"
#include "application.h"
#include "settings.h"
#include <esp_log.h>
#include <cstring>
#include <arpa/inet.h>
#include "assets/lang_config.h"
#define TAG "MQTT"
MqttProtocol::MqttProtocol() {
event_group_handle_ = xEventGroupCreate();
// Initialize reconnect timer
esp_timer_create_args_t reconnect_timer_args = {
.callback = [](void* arg) {
MqttProtocol* protocol = (MqttProtocol*)arg;
auto& app = Application::GetInstance();
if (app.GetDeviceState() == kDeviceStateIdle) {
ESP_LOGI(TAG, "Reconnecting to MQTT server");
app.Schedule([protocol]() {
protocol->StartMqttClient(false);
});
}
},
.arg = this,
};
esp_timer_create(&reconnect_timer_args, &reconnect_timer_);
}
MqttProtocol::~MqttProtocol() {
ESP_LOGI(TAG, "MqttProtocol deinit");
if (reconnect_timer_ != nullptr) {
esp_timer_stop(reconnect_timer_);
esp_timer_delete(reconnect_timer_);
}
udp_.reset();
mqtt_.reset();
if (event_group_handle_ != nullptr) {
vEventGroupDelete(event_group_handle_);
}
}
bool MqttProtocol::Start() {
return StartMqttClient(false);
}
bool MqttProtocol::StartMqttClient(bool report_error) {
if (mqtt_ != nullptr) {
ESP_LOGW(TAG, "Mqtt client already started");
mqtt_.reset();
}
Settings settings("mqtt", false);
auto endpoint = settings.GetString("endpoint");
auto client_id = settings.GetString("client_id");
auto username = settings.GetString("username");
auto password = settings.GetString("password");
int keepalive_interval = settings.GetInt("keepalive", 240);
publish_topic_ = settings.GetString("publish_topic");
if (endpoint.empty()) {
ESP_LOGW(TAG, "MQTT endpoint is not specified");
if (report_error) {
SetError(Lang::Strings::SERVER_NOT_FOUND);
}
return false;
}
auto network = Board::GetInstance().GetNetwork();
mqtt_ = network->CreateMqtt(0);
mqtt_->SetKeepAlive(keepalive_interval);
mqtt_->OnDisconnected([this]() {
if (on_disconnected_ != nullptr) {
on_disconnected_();
}
ESP_LOGI(TAG, "MQTT disconnected, schedule reconnect in %d seconds", MQTT_RECONNECT_INTERVAL_MS / 1000);
esp_timer_start_once(reconnect_timer_, MQTT_RECONNECT_INTERVAL_MS * 1000);
});
mqtt_->OnConnected([this]() {
if (on_connected_ != nullptr) {
on_connected_();
}
esp_timer_stop(reconnect_timer_);
});
mqtt_->OnMessage([this](const std::string& topic, const std::string& payload) {
cJSON* root = cJSON_Parse(payload.c_str());
if (root == nullptr) {
ESP_LOGE(TAG, "Failed to parse json message %s", payload.c_str());
return;
}
cJSON* type = cJSON_GetObjectItem(root, "type");
if (!cJSON_IsString(type)) {
ESP_LOGE(TAG, "Message type is invalid");
cJSON_Delete(root);
return;
}
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else if (strcmp(type->valuestring, "goodbye") == 0) {
auto session_id = cJSON_GetObjectItem(root, "session_id");
ESP_LOGI(TAG, "Received goodbye message, session_id: %s", session_id ? session_id->valuestring : "null");
if (session_id == nullptr || session_id_ == session_id->valuestring) {
Application::GetInstance().Schedule([this]() {
CloseAudioChannel();
});
}
} else if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
cJSON_Delete(root);
last_incoming_time_ = std::chrono::steady_clock::now();
});
ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint.c_str());
std::string broker_address;
int broker_port = 8883;
size_t pos = endpoint.find(':');
if (pos != std::string::npos) {
broker_address = endpoint.substr(0, pos);
broker_port = std::stoi(endpoint.substr(pos + 1));
} else {
broker_address = endpoint;
}
if (!mqtt_->Connect(broker_address, broker_port, client_id, username, password)) {
ESP_LOGE(TAG, "Failed to connect to endpoint");
SetError(Lang::Strings::SERVER_NOT_CONNECTED);
return false;
}
ESP_LOGI(TAG, "Connected to endpoint");
return true;
}
bool MqttProtocol::SendText(const std::string& text) {
if (publish_topic_.empty()) {
return false;
}
if (!mqtt_->Publish(publish_topic_, text)) {
ESP_LOGE(TAG, "Failed to publish message: %s", text.c_str());
SetError(Lang::Strings::SERVER_ERROR);
return false;
}
return true;
}
bool MqttProtocol::SendAudio(std::unique_ptr<AudioStreamPacket> packet) {
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return false;
}
std::string nonce(aes_nonce_);
*(uint16_t*)&nonce[2] = htons(packet->payload.size());
*(uint32_t*)&nonce[8] = htonl(packet->timestamp);
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
std::string encrypted;
encrypted.resize(aes_nonce_.size() + packet->payload.size());
memcpy(encrypted.data(), nonce.data(), nonce.size());
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
if (mbedtls_aes_crypt_ctr(&aes_ctx_, packet->payload.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block,
(uint8_t*)packet->payload.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) {
ESP_LOGE(TAG, "Failed to encrypt audio data");
return false;
}
return udp_->Send(encrypted) > 0;
}
void MqttProtocol::CloseAudioChannel() {
{
std::lock_guard<std::mutex> lock(channel_mutex_);
udp_.reset();
}
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"goodbye\"";
message += "}";
SendText(message);
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
}
bool MqttProtocol::OpenAudioChannel() {
if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
if (!StartMqttClient(true)) {
return false;
}
}
error_occurred_ = false;
session_id_ = "";
xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
auto message = GetHelloMessage();
if (!SendText(message)) {
return false;
}
// 等待服务器响应
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) {
ESP_LOGE(TAG, "Failed to receive server hello");
SetError(Lang::Strings::SERVER_TIMEOUT);
return false;
}
std::lock_guard<std::mutex> lock(channel_mutex_);
auto network = Board::GetInstance().GetNetwork();
udp_ = network->CreateUdp(2);
udp_->OnMessage([this](const std::string& data) {
/*
* UDP Encrypted OPUS Packet Format:
* |type 1u|flags 1u|payload_len 2u|ssrc 4u|timestamp 4u|sequence 4u|
* |payload payload_len|
*/
if (data.size() < sizeof(aes_nonce_)) {
ESP_LOGE(TAG, "Invalid audio packet size: %u", data.size());
return;
}
if (data[0] != 0x01) {
ESP_LOGE(TAG, "Invalid audio packet type: %x", data[0]);
return;
}
uint32_t timestamp = ntohl(*(uint32_t*)&data[8]);
uint32_t sequence = ntohl(*(uint32_t*)&data[12]);
if (sequence < remote_sequence_) {
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
return;
}
if (sequence != remote_sequence_ + 1) {
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
}
size_t decrypted_size = data.size() - aes_nonce_.size();
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
auto nonce = (uint8_t*)data.data();
auto encrypted = (uint8_t*)data.data() + aes_nonce_.size();
auto packet = std::make_unique<AudioStreamPacket>();
packet->sample_rate = server_sample_rate_;
packet->frame_duration = server_frame_duration_;
packet->timestamp = timestamp;
packet->payload.resize(decrypted_size);
int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)packet->payload.data());
if (ret != 0) {
ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret);
return;
}
if (on_incoming_audio_ != nullptr) {
on_incoming_audio_(std::move(packet));
}
remote_sequence_ = sequence;
last_incoming_time_ = std::chrono::steady_clock::now();
});
udp_->Connect(udp_server_, udp_port_);
if (on_audio_channel_opened_ != nullptr) {
on_audio_channel_opened_();
}
return true;
}
std::string MqttProtocol::GetHelloMessage() {
// 发送 hello 消息申请 UDP 通道
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "hello");
cJSON_AddNumberToObject(root, "version", 3);
cJSON_AddStringToObject(root, "transport", "udp");
cJSON* features = cJSON_CreateObject();
#if CONFIG_USE_SERVER_AEC
cJSON_AddBoolToObject(features, "aec", true);
#endif
cJSON_AddBoolToObject(features, "mcp", true);
cJSON_AddItemToObject(root, "features", features);
cJSON* audio_params = cJSON_CreateObject();
cJSON_AddStringToObject(audio_params, "format", "opus");
cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
cJSON_AddNumberToObject(audio_params, "channels", 1);
cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
cJSON_AddItemToObject(root, "audio_params", audio_params);
auto json_str = cJSON_PrintUnformatted(root);
std::string message(json_str);
cJSON_free(json_str);
cJSON_Delete(root);
return message;
}
void MqttProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
return;
}
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (cJSON_IsString(session_id)) {
session_id_ = session_id->valuestring;
ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
}
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (cJSON_IsObject(audio_params)) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (cJSON_IsNumber(sample_rate)) {
server_sample_rate_ = sample_rate->valueint;
}
auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
if (cJSON_IsNumber(frame_duration)) {
server_frame_duration_ = frame_duration->valueint;
}
}
auto udp = cJSON_GetObjectItem(root, "udp");
if (!cJSON_IsObject(udp)) {
ESP_LOGE(TAG, "UDP is not specified");
return;
}
udp_server_ = cJSON_GetObjectItem(udp, "server")->valuestring;
udp_port_ = cJSON_GetObjectItem(udp, "port")->valueint;
auto key = cJSON_GetObjectItem(udp, "key")->valuestring;
auto nonce = cJSON_GetObjectItem(udp, "nonce")->valuestring;
// auto encryption = cJSON_GetObjectItem(udp, "encryption")->valuestring;
// ESP_LOGI(TAG, "UDP server: %s, port: %d, encryption: %s", udp_server_.c_str(), udp_port_, encryption);
aes_nonce_ = DecodeHexString(nonce);
mbedtls_aes_init(&aes_ctx_);
mbedtls_aes_setkey_enc(&aes_ctx_, (const unsigned char*)DecodeHexString(key).c_str(), 128);
local_sequence_ = 0;
remote_sequence_ = 0;
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
}
static const char hex_chars[] = "0123456789ABCDEF";
// 辅助函数,将单个十六进制字符转换为对应的数值
static inline uint8_t CharToHex(char c) {
if (c >= '0' && c <= '9') return c - '0';
if (c >= 'A' && c <= 'F') return c - 'A' + 10;
if (c >= 'a' && c <= 'f') return c - 'a' + 10;
return 0; // 对于无效输入返回0
}
std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
std::string decoded;
decoded.reserve(hex_string.size() / 2);
for (size_t i = 0; i < hex_string.size(); i += 2) {
char byte = (CharToHex(hex_string[i]) << 4) | CharToHex(hex_string[i + 1]);
decoded.push_back(byte);
}
return decoded;
}
bool MqttProtocol::IsAudioChannelOpened() const {
return udp_ != nullptr && !error_occurred_ && !IsTimeout();
}

View File

@@ -1,60 +1,60 @@
#ifndef MQTT_PROTOCOL_H
#define MQTT_PROTOCOL_H
#include "protocol.h"
#include <mqtt.h>
#include <udp.h>
#include <cJSON.h>
#include <mbedtls/aes.h>
#include <freertos/FreeRTOS.h>
#include <freertos/event_groups.h>
#include <esp_timer.h>
#include <functional>
#include <string>
#include <map>
#include <mutex>
#define MQTT_PING_INTERVAL_SECONDS 90
#define MQTT_RECONNECT_INTERVAL_MS 60000
#define MQTT_PROTOCOL_SERVER_HELLO_EVENT (1 << 0)
class MqttProtocol : public Protocol {
public:
MqttProtocol();
~MqttProtocol();
bool Start() override;
bool SendAudio(std::unique_ptr<AudioStreamPacket> packet) override;
bool OpenAudioChannel() override;
void CloseAudioChannel() override;
bool IsAudioChannelOpened() const override;
private:
EventGroupHandle_t event_group_handle_;
std::string publish_topic_;
std::mutex channel_mutex_;
std::unique_ptr<Mqtt> mqtt_;
std::unique_ptr<Udp> udp_;
mbedtls_aes_context aes_ctx_;
std::string aes_nonce_;
std::string udp_server_;
int udp_port_;
uint32_t local_sequence_;
uint32_t remote_sequence_;
esp_timer_handle_t reconnect_timer_;
bool StartMqttClient(bool report_error=false);
void ParseServerHello(const cJSON* root);
std::string DecodeHexString(const std::string& hex_string);
bool SendText(const std::string& text) override;
std::string GetHelloMessage();
};
#endif // MQTT_PROTOCOL_H
#ifndef MQTT_PROTOCOL_H
#define MQTT_PROTOCOL_H
#include "protocol.h"
#include <mqtt.h>
#include <udp.h>
#include <cJSON.h>
#include <mbedtls/aes.h>
#include <freertos/FreeRTOS.h>
#include <freertos/event_groups.h>
#include <esp_timer.h>
#include <functional>
#include <string>
#include <map>
#include <mutex>
#define MQTT_PING_INTERVAL_SECONDS 90
#define MQTT_RECONNECT_INTERVAL_MS 60000
#define MQTT_PROTOCOL_SERVER_HELLO_EVENT (1 << 0)
class MqttProtocol : public Protocol {
public:
MqttProtocol();
~MqttProtocol();
bool Start() override;
bool SendAudio(std::unique_ptr<AudioStreamPacket> packet) override;
bool OpenAudioChannel() override;
void CloseAudioChannel() override;
bool IsAudioChannelOpened() const override;
private:
EventGroupHandle_t event_group_handle_;
std::string publish_topic_;
std::mutex channel_mutex_;
std::unique_ptr<Mqtt> mqtt_;
std::unique_ptr<Udp> udp_;
mbedtls_aes_context aes_ctx_;
std::string aes_nonce_;
std::string udp_server_;
int udp_port_;
uint32_t local_sequence_;
uint32_t remote_sequence_;
esp_timer_handle_t reconnect_timer_;
bool StartMqttClient(bool report_error=false);
void ParseServerHello(const cJSON* root);
std::string DecodeHexString(const std::string& hex_string);
bool SendText(const std::string& text) override;
std::string GetHelloMessage();
};
#endif // MQTT_PROTOCOL_H

View File

@@ -1,90 +1,90 @@
#include "protocol.h"
#include <esp_log.h>
#define TAG "Protocol"
void Protocol::OnIncomingJson(std::function<void(const cJSON* root)> callback) {
on_incoming_json_ = callback;
}
void Protocol::OnIncomingAudio(std::function<void(std::unique_ptr<AudioStreamPacket> packet)> callback) {
on_incoming_audio_ = callback;
}
void Protocol::OnAudioChannelOpened(std::function<void()> callback) {
on_audio_channel_opened_ = callback;
}
void Protocol::OnAudioChannelClosed(std::function<void()> callback) {
on_audio_channel_closed_ = callback;
}
void Protocol::OnNetworkError(std::function<void(const std::string& message)> callback) {
on_network_error_ = callback;
}
void Protocol::OnConnected(std::function<void()> callback) {
on_connected_ = callback;
}
void Protocol::OnDisconnected(std::function<void()> callback) {
on_disconnected_ = callback;
}
void Protocol::SetError(const std::string& message) {
error_occurred_ = true;
if (on_network_error_ != nullptr) {
on_network_error_(message);
}
}
void Protocol::SendAbortSpeaking(AbortReason reason) {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"abort\"";
if (reason == kAbortReasonWakeWordDetected) {
message += ",\"reason\":\"wake_word_detected\"";
}
message += "}";
SendText(message);
}
void Protocol::SendWakeWordDetected(const std::string& wake_word) {
std::string json = "{\"session_id\":\"" + session_id_ +
"\",\"type\":\"listen\",\"state\":\"detect\",\"text\":\"" + wake_word + "\"}";
SendText(json);
}
void Protocol::SendStartListening(ListeningMode mode) {
std::string message = "{\"session_id\":\"" + session_id_ + "\"";
message += ",\"type\":\"listen\",\"state\":\"start\"";
if (mode == kListeningModeRealtime) {
message += ",\"mode\":\"realtime\"";
} else if (mode == kListeningModeAutoStop) {
message += ",\"mode\":\"auto\"";
} else {
message += ",\"mode\":\"manual\"";
}
message += "}";
SendText(message);
}
void Protocol::SendStopListening() {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"listen\",\"state\":\"stop\"}";
SendText(message);
}
void Protocol::SendMcpMessage(const std::string& payload) {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"mcp\",\"payload\":" + payload + "}";
SendText(message);
}
bool Protocol::IsTimeout() const {
const int kTimeoutSeconds = 120;
auto now = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>(now - last_incoming_time_);
bool timeout = duration.count() > kTimeoutSeconds;
if (timeout) {
ESP_LOGE(TAG, "Channel timeout %ld seconds", (long)duration.count());
}
return timeout;
}
#include "protocol.h"
#include <esp_log.h>
#define TAG "Protocol"
void Protocol::OnIncomingJson(std::function<void(const cJSON* root)> callback) {
on_incoming_json_ = callback;
}
void Protocol::OnIncomingAudio(std::function<void(std::unique_ptr<AudioStreamPacket> packet)> callback) {
on_incoming_audio_ = callback;
}
void Protocol::OnAudioChannelOpened(std::function<void()> callback) {
on_audio_channel_opened_ = callback;
}
void Protocol::OnAudioChannelClosed(std::function<void()> callback) {
on_audio_channel_closed_ = callback;
}
void Protocol::OnNetworkError(std::function<void(const std::string& message)> callback) {
on_network_error_ = callback;
}
void Protocol::OnConnected(std::function<void()> callback) {
on_connected_ = callback;
}
void Protocol::OnDisconnected(std::function<void()> callback) {
on_disconnected_ = callback;
}
void Protocol::SetError(const std::string& message) {
error_occurred_ = true;
if (on_network_error_ != nullptr) {
on_network_error_(message);
}
}
void Protocol::SendAbortSpeaking(AbortReason reason) {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"abort\"";
if (reason == kAbortReasonWakeWordDetected) {
message += ",\"reason\":\"wake_word_detected\"";
}
message += "}";
SendText(message);
}
void Protocol::SendWakeWordDetected(const std::string& wake_word) {
std::string json = "{\"session_id\":\"" + session_id_ +
"\",\"type\":\"listen\",\"state\":\"detect\",\"text\":\"" + wake_word + "\"}";
SendText(json);
}
void Protocol::SendStartListening(ListeningMode mode) {
std::string message = "{\"session_id\":\"" + session_id_ + "\"";
message += ",\"type\":\"listen\",\"state\":\"start\"";
if (mode == kListeningModeRealtime) {
message += ",\"mode\":\"realtime\"";
} else if (mode == kListeningModeAutoStop) {
message += ",\"mode\":\"auto\"";
} else {
message += ",\"mode\":\"manual\"";
}
message += "}";
SendText(message);
}
void Protocol::SendStopListening() {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"listen\",\"state\":\"stop\"}";
SendText(message);
}
void Protocol::SendMcpMessage(const std::string& payload) {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"mcp\",\"payload\":" + payload + "}";
SendText(message);
}
bool Protocol::IsTimeout() const {
const int kTimeoutSeconds = 120;
auto now = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>(now - last_incoming_time_);
bool timeout = duration.count() > kTimeoutSeconds;
if (timeout) {
ESP_LOGE(TAG, "Channel timeout %ld seconds", (long)duration.count());
}
return timeout;
}

View File

@@ -1,98 +1,98 @@
#ifndef PROTOCOL_H
#define PROTOCOL_H
#include <cJSON.h>
#include <string>
#include <functional>
#include <chrono>
#include <vector>
struct AudioStreamPacket {
int sample_rate = 0;
int frame_duration = 0;
uint32_t timestamp = 0;
std::vector<uint8_t> payload;
};
struct BinaryProtocol2 {
uint16_t version;
uint16_t type; // Message type (0: OPUS, 1: JSON)
uint32_t reserved; // Reserved for future use
uint32_t timestamp; // Timestamp in milliseconds (used for server-side AEC)
uint32_t payload_size; // Payload size in bytes
uint8_t payload[]; // Payload data
} __attribute__((packed));
struct BinaryProtocol3 {
uint8_t type;
uint8_t reserved;
uint16_t payload_size;
uint8_t payload[];
} __attribute__((packed));
enum AbortReason {
kAbortReasonNone,
kAbortReasonWakeWordDetected
};
enum ListeningMode {
kListeningModeAutoStop,
kListeningModeManualStop,
kListeningModeRealtime // 需要 AEC 支持
};
class Protocol {
public:
virtual ~Protocol() = default;
inline int server_sample_rate() const {
return server_sample_rate_;
}
inline int server_frame_duration() const {
return server_frame_duration_;
}
inline const std::string& session_id() const {
return session_id_;
}
void OnIncomingAudio(std::function<void(std::unique_ptr<AudioStreamPacket> packet)> callback);
void OnIncomingJson(std::function<void(const cJSON* root)> callback);
void OnAudioChannelOpened(std::function<void()> callback);
void OnAudioChannelClosed(std::function<void()> callback);
void OnNetworkError(std::function<void(const std::string& message)> callback);
void OnConnected(std::function<void()> callback);
void OnDisconnected(std::function<void()> callback);
virtual bool Start() = 0;
virtual bool OpenAudioChannel() = 0;
virtual void CloseAudioChannel() = 0;
virtual bool IsAudioChannelOpened() const = 0;
virtual bool SendAudio(std::unique_ptr<AudioStreamPacket> packet) = 0;
virtual void SendWakeWordDetected(const std::string& wake_word);
virtual void SendStartListening(ListeningMode mode);
virtual void SendStopListening();
virtual void SendAbortSpeaking(AbortReason reason);
virtual void SendMcpMessage(const std::string& message);
protected:
std::function<void(const cJSON* root)> on_incoming_json_;
std::function<void(std::unique_ptr<AudioStreamPacket> packet)> on_incoming_audio_;
std::function<void()> on_audio_channel_opened_;
std::function<void()> on_audio_channel_closed_;
std::function<void(const std::string& message)> on_network_error_;
std::function<void()> on_connected_;
std::function<void()> on_disconnected_;
int server_sample_rate_ = 24000;
int server_frame_duration_ = 60;
bool error_occurred_ = false;
std::string session_id_;
std::chrono::time_point<std::chrono::steady_clock> last_incoming_time_;
virtual bool SendText(const std::string& text) = 0;
virtual void SetError(const std::string& message);
virtual bool IsTimeout() const;
};
#endif // PROTOCOL_H
#ifndef PROTOCOL_H
#define PROTOCOL_H
#include <cJSON.h>
#include <string>
#include <functional>
#include <chrono>
#include <vector>
struct AudioStreamPacket {
int sample_rate = 0;
int frame_duration = 0;
uint32_t timestamp = 0;
std::vector<uint8_t> payload;
};
struct BinaryProtocol2 {
uint16_t version;
uint16_t type; // Message type (0: OPUS, 1: JSON)
uint32_t reserved; // Reserved for future use
uint32_t timestamp; // Timestamp in milliseconds (used for server-side AEC)
uint32_t payload_size; // Payload size in bytes
uint8_t payload[]; // Payload data
} __attribute__((packed));
struct BinaryProtocol3 {
uint8_t type;
uint8_t reserved;
uint16_t payload_size;
uint8_t payload[];
} __attribute__((packed));
enum AbortReason {
kAbortReasonNone,
kAbortReasonWakeWordDetected
};
enum ListeningMode {
kListeningModeAutoStop,
kListeningModeManualStop,
kListeningModeRealtime // 需要 AEC 支持
};
class Protocol {
public:
virtual ~Protocol() = default;
inline int server_sample_rate() const {
return server_sample_rate_;
}
inline int server_frame_duration() const {
return server_frame_duration_;
}
inline const std::string& session_id() const {
return session_id_;
}
void OnIncomingAudio(std::function<void(std::unique_ptr<AudioStreamPacket> packet)> callback);
void OnIncomingJson(std::function<void(const cJSON* root)> callback);
void OnAudioChannelOpened(std::function<void()> callback);
void OnAudioChannelClosed(std::function<void()> callback);
void OnNetworkError(std::function<void(const std::string& message)> callback);
void OnConnected(std::function<void()> callback);
void OnDisconnected(std::function<void()> callback);
virtual bool Start() = 0;
virtual bool OpenAudioChannel() = 0;
virtual void CloseAudioChannel() = 0;
virtual bool IsAudioChannelOpened() const = 0;
virtual bool SendAudio(std::unique_ptr<AudioStreamPacket> packet) = 0;
virtual void SendWakeWordDetected(const std::string& wake_word);
virtual void SendStartListening(ListeningMode mode);
virtual void SendStopListening();
virtual void SendAbortSpeaking(AbortReason reason);
virtual void SendMcpMessage(const std::string& message);
protected:
std::function<void(const cJSON* root)> on_incoming_json_;
std::function<void(std::unique_ptr<AudioStreamPacket> packet)> on_incoming_audio_;
std::function<void()> on_audio_channel_opened_;
std::function<void()> on_audio_channel_closed_;
std::function<void(const std::string& message)> on_network_error_;
std::function<void()> on_connected_;
std::function<void()> on_disconnected_;
int server_sample_rate_ = 24000;
int server_frame_duration_ = 60;
bool error_occurred_ = false;
std::string session_id_;
std::chrono::time_point<std::chrono::steady_clock> last_incoming_time_;
virtual bool SendText(const std::string& text) = 0;
virtual void SetError(const std::string& message);
virtual bool IsTimeout() const;
};
#endif // PROTOCOL_H

View File

@@ -1,105 +0,0 @@
#include "sleep_music_protocol.h"
#include "board.h"
#include "application.h"
#include "protocol.h"
#include <cstring>
#include <esp_log.h>
#define TAG "SleepMusic"
SleepMusicProtocol& SleepMusicProtocol::GetInstance() {
static SleepMusicProtocol instance;
return instance;
}
SleepMusicProtocol::SleepMusicProtocol() {
event_group_handle_ = xEventGroupCreate();
}
SleepMusicProtocol::~SleepMusicProtocol() {
vEventGroupDelete(event_group_handle_);
}
bool SleepMusicProtocol::IsAudioChannelOpened() const {
return is_connected_ && websocket_ != nullptr && websocket_->IsConnected();
}
void SleepMusicProtocol::CloseAudioChannel() {
if (websocket_) {
ESP_LOGI(TAG, "Closing sleep music audio channel");
// 清理状态
is_connected_ = false;
// 关闭WebSocket连接
websocket_.reset();
ESP_LOGI(TAG, "Sleep music audio channel closed");
}
}
bool SleepMusicProtocol::OpenAudioChannel() {
std::string url = "ws://180.76.190.230:8765";
ESP_LOGI(TAG, "Connecting to sleep music server: %s", url.c_str());
auto network = Board::GetInstance().GetNetwork();
websocket_ = network->CreateWebSocket(2); // 使用不同的WebSocket实例ID
if (websocket_ == nullptr) {
ESP_LOGE(TAG, "Failed to create websocket for sleep music");
return false;
}
// 设置WebSocket数据接收回调
websocket_->OnData([this](const char* data, size_t len, bool binary) {
if (binary) {
// 接收到的二进制数据是OPUS编码的音频帧
OnAudioDataReceived(data, len);
} else {
ESP_LOGW(TAG, "Received non-binary data from sleep music server, ignoring");
}
});
websocket_->OnDisconnected([this]() {
ESP_LOGI(TAG, "Sleep music websocket disconnected");
});
// 连接到睡眠音乐服务器
if (!websocket_->Connect(url.c_str())) {
ESP_LOGE(TAG, "Failed to connect to sleep music server");
return false;
}
// 设置连接成功事件
xEventGroupSetBits(event_group_handle_, SLEEP_MUSIC_PROTOCOL_CONNECTED_EVENT);
ESP_LOGI(TAG, "Successfully connected to sleep music server");
is_connected_ = true;
return true;
}
void SleepMusicProtocol::OnAudioDataReceived(const char* data, size_t len) {
if (len == 0) {
ESP_LOGW(TAG, "Received empty audio data");
return;
}
ESP_LOGD(TAG, "Received audio frame: %zu bytes", len);
// 创建AudioStreamPacket
auto packet = std::make_unique<AudioStreamPacket>();
packet->sample_rate = SAMPLE_RATE;
packet->frame_duration = FRAME_DURATION_MS;
packet->timestamp = 0; // 睡眠音乐不需要时间戳同步
packet->payload.resize(len);
std::memcpy(packet->payload.data(), data, len);
// 将音频包推入解码队列
auto& app = Application::GetInstance();
auto& audio_service = app.GetAudioService();
if (!audio_service.PushPacketToDecodeQueue(std::move(packet), false)) {
ESP_LOGW(TAG, "Audio decode queue is full, dropping packet");
}
}

View File

@@ -1,35 +0,0 @@
#ifndef _SLEEP_MUSIC_PROTOCOL_H_
#define _SLEEP_MUSIC_PROTOCOL_H_
#include <web_socket.h>
#include <freertos/FreeRTOS.h>
#include <freertos/event_groups.h>
#include <memory>
#define SLEEP_MUSIC_PROTOCOL_CONNECTED_EVENT (1 << 0)
class SleepMusicProtocol {
public:
static SleepMusicProtocol& GetInstance();
bool OpenAudioChannel();
void CloseAudioChannel();
bool IsAudioChannelOpened() const;
private:
SleepMusicProtocol();
~SleepMusicProtocol();
EventGroupHandle_t event_group_handle_;
std::unique_ptr<WebSocket> websocket_;
bool is_connected_ = false;
// 睡眠音乐服务器配置
static constexpr int SAMPLE_RATE = 24000; // 24kHz
static constexpr int CHANNELS = 2; // 立体声
static constexpr int FRAME_DURATION_MS = 60; // 60ms帧时长
void OnAudioDataReceived(const char* data, size_t len);
};
#endif

View File

@@ -1,253 +1,253 @@
#include "websocket_protocol.h"
#include "board.h"
#include "system_info.h"
#include "application.h"
#include "settings.h"
#include <cstring>
#include <cJSON.h>
#include <esp_log.h>
#include <arpa/inet.h>
#include "assets/lang_config.h"
#define TAG "WS"
WebsocketProtocol::WebsocketProtocol() {
event_group_handle_ = xEventGroupCreate();
}
WebsocketProtocol::~WebsocketProtocol() {
vEventGroupDelete(event_group_handle_);
}
bool WebsocketProtocol::Start() {
// Only connect to server when audio channel is needed
return true;
}
bool WebsocketProtocol::SendAudio(std::unique_ptr<AudioStreamPacket> packet) {
if (websocket_ == nullptr || !websocket_->IsConnected()) {
return false;
}
if (version_ == 2) {
std::string serialized;
serialized.resize(sizeof(BinaryProtocol2) + packet->payload.size());
auto bp2 = (BinaryProtocol2*)serialized.data();
bp2->version = htons(version_);
bp2->type = 0;
bp2->reserved = 0;
bp2->timestamp = htonl(packet->timestamp);
bp2->payload_size = htonl(packet->payload.size());
memcpy(bp2->payload, packet->payload.data(), packet->payload.size());
return websocket_->Send(serialized.data(), serialized.size(), true);
} else if (version_ == 3) {
std::string serialized;
serialized.resize(sizeof(BinaryProtocol3) + packet->payload.size());
auto bp3 = (BinaryProtocol3*)serialized.data();
bp3->type = 0;
bp3->reserved = 0;
bp3->payload_size = htons(packet->payload.size());
memcpy(bp3->payload, packet->payload.data(), packet->payload.size());
return websocket_->Send(serialized.data(), serialized.size(), true);
} else {
return websocket_->Send(packet->payload.data(), packet->payload.size(), true);
}
}
bool WebsocketProtocol::SendText(const std::string& text) {
if (websocket_ == nullptr || !websocket_->IsConnected()) {
return false;
}
if (!websocket_->Send(text)) {
ESP_LOGE(TAG, "Failed to send text: %s", text.c_str());
SetError(Lang::Strings::SERVER_ERROR);
return false;
}
return true;
}
bool WebsocketProtocol::IsAudioChannelOpened() const {
return websocket_ != nullptr && websocket_->IsConnected() && !error_occurred_ && !IsTimeout();
}
void WebsocketProtocol::CloseAudioChannel() {
websocket_.reset();
}
bool WebsocketProtocol::OpenAudioChannel() {
Settings settings("websocket", false);
std::string url = settings.GetString("url");
std::string token = settings.GetString("token");
int version = settings.GetInt("version");
if (version != 0) {
version_ = version;
}
error_occurred_ = false;
auto network = Board::GetInstance().GetNetwork();
websocket_ = network->CreateWebSocket(1);
if (websocket_ == nullptr) {
ESP_LOGE(TAG, "Failed to create websocket");
return false;
}
if (!token.empty()) {
// If token not has a space, add "Bearer " prefix
if (token.find(" ") == std::string::npos) {
token = "Bearer " + token;
}
websocket_->SetHeader("Authorization", token.c_str());
}
websocket_->SetHeader("Protocol-Version", std::to_string(version_).c_str());
websocket_->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str());
websocket_->SetHeader("Client-Id", Board::GetInstance().GetUuid().c_str());
websocket_->OnData([this](const char* data, size_t len, bool binary) {
if (binary) {
if (on_incoming_audio_ != nullptr) {
if (version_ == 2) {
BinaryProtocol2* bp2 = (BinaryProtocol2*)data;
bp2->version = ntohs(bp2->version);
bp2->type = ntohs(bp2->type);
bp2->timestamp = ntohl(bp2->timestamp);
bp2->payload_size = ntohl(bp2->payload_size);
auto payload = (uint8_t*)bp2->payload;
on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
.sample_rate = server_sample_rate_,
.frame_duration = server_frame_duration_,
.timestamp = bp2->timestamp,
.payload = std::vector<uint8_t>(payload, payload + bp2->payload_size)
}));
} else if (version_ == 3) {
BinaryProtocol3* bp3 = (BinaryProtocol3*)data;
bp3->type = bp3->type;
bp3->payload_size = ntohs(bp3->payload_size);
auto payload = (uint8_t*)bp3->payload;
on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
.sample_rate = server_sample_rate_,
.frame_duration = server_frame_duration_,
.timestamp = 0,
.payload = std::vector<uint8_t>(payload, payload + bp3->payload_size)
}));
} else {
on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
.sample_rate = server_sample_rate_,
.frame_duration = server_frame_duration_,
.timestamp = 0,
.payload = std::vector<uint8_t>((uint8_t*)data, (uint8_t*)data + len)
}));
}
}
} else {
// Parse JSON data
auto root = cJSON_Parse(data);
auto type = cJSON_GetObjectItem(root, "type");
if (cJSON_IsString(type)) {
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else {
if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
}
} else {
ESP_LOGE(TAG, "Missing message type, data: %s", data);
}
cJSON_Delete(root);
}
last_incoming_time_ = std::chrono::steady_clock::now();
});
websocket_->OnDisconnected([this]() {
ESP_LOGI(TAG, "Websocket disconnected");
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
});
ESP_LOGI(TAG, "Connecting to websocket server: %s with version: %d", url.c_str(), version_);
if (!websocket_->Connect(url.c_str())) {
ESP_LOGE(TAG, "Failed to connect to websocket server");
SetError(Lang::Strings::SERVER_NOT_CONNECTED);
return false;
}
// Send hello message to describe the client
auto message = GetHelloMessage();
if (!SendText(message)) {
return false;
}
// Wait for server hello
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
if (!(bits & WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT)) {
ESP_LOGE(TAG, "Failed to receive server hello");
SetError(Lang::Strings::SERVER_TIMEOUT);
return false;
}
if (on_audio_channel_opened_ != nullptr) {
on_audio_channel_opened_();
}
return true;
}
std::string WebsocketProtocol::GetHelloMessage() {
// keys: message type, version, audio_params (format, sample_rate, channels)
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "hello");
cJSON_AddNumberToObject(root, "version", version_);
cJSON* features = cJSON_CreateObject();
#if CONFIG_USE_SERVER_AEC
cJSON_AddBoolToObject(features, "aec", true);
#endif
cJSON_AddBoolToObject(features, "mcp", true);
cJSON_AddItemToObject(root, "features", features);
cJSON_AddStringToObject(root, "transport", "websocket");
cJSON* audio_params = cJSON_CreateObject();
cJSON_AddStringToObject(audio_params, "format", "opus");
cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
cJSON_AddNumberToObject(audio_params, "channels", 1);
cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
cJSON_AddItemToObject(root, "audio_params", audio_params);
auto json_str = cJSON_PrintUnformatted(root);
std::string message(json_str);
cJSON_free(json_str);
cJSON_Delete(root);
return message;
}
void WebsocketProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "websocket") != 0) {
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
return;
}
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (cJSON_IsString(session_id)) {
session_id_ = session_id->valuestring;
ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
}
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (cJSON_IsObject(audio_params)) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (cJSON_IsNumber(sample_rate)) {
server_sample_rate_ = sample_rate->valueint;
}
auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
if (cJSON_IsNumber(frame_duration)) {
server_frame_duration_ = frame_duration->valueint;
}
}
xEventGroupSetBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT);
}
#include "websocket_protocol.h"
#include "board.h"
#include "system_info.h"
#include "application.h"
#include "settings.h"
#include <cstring>
#include <cJSON.h>
#include <esp_log.h>
#include <arpa/inet.h>
#include "assets/lang_config.h"
#define TAG "WS"
WebsocketProtocol::WebsocketProtocol() {
event_group_handle_ = xEventGroupCreate();
}
WebsocketProtocol::~WebsocketProtocol() {
vEventGroupDelete(event_group_handle_);
}
bool WebsocketProtocol::Start() {
// Only connect to server when audio channel is needed
return true;
}
bool WebsocketProtocol::SendAudio(std::unique_ptr<AudioStreamPacket> packet) {
if (websocket_ == nullptr || !websocket_->IsConnected()) {
return false;
}
if (version_ == 2) {
std::string serialized;
serialized.resize(sizeof(BinaryProtocol2) + packet->payload.size());
auto bp2 = (BinaryProtocol2*)serialized.data();
bp2->version = htons(version_);
bp2->type = 0;
bp2->reserved = 0;
bp2->timestamp = htonl(packet->timestamp);
bp2->payload_size = htonl(packet->payload.size());
memcpy(bp2->payload, packet->payload.data(), packet->payload.size());
return websocket_->Send(serialized.data(), serialized.size(), true);
} else if (version_ == 3) {
std::string serialized;
serialized.resize(sizeof(BinaryProtocol3) + packet->payload.size());
auto bp3 = (BinaryProtocol3*)serialized.data();
bp3->type = 0;
bp3->reserved = 0;
bp3->payload_size = htons(packet->payload.size());
memcpy(bp3->payload, packet->payload.data(), packet->payload.size());
return websocket_->Send(serialized.data(), serialized.size(), true);
} else {
return websocket_->Send(packet->payload.data(), packet->payload.size(), true);
}
}
bool WebsocketProtocol::SendText(const std::string& text) {
if (websocket_ == nullptr || !websocket_->IsConnected()) {
return false;
}
if (!websocket_->Send(text)) {
ESP_LOGE(TAG, "Failed to send text: %s", text.c_str());
SetError(Lang::Strings::SERVER_ERROR);
return false;
}
return true;
}
bool WebsocketProtocol::IsAudioChannelOpened() const {
return websocket_ != nullptr && websocket_->IsConnected() && !error_occurred_ && !IsTimeout();
}
void WebsocketProtocol::CloseAudioChannel() {
websocket_.reset();
}
bool WebsocketProtocol::OpenAudioChannel() {
Settings settings("websocket", false);
std::string url = settings.GetString("url");
std::string token = settings.GetString("token");
int version = settings.GetInt("version");
if (version != 0) {
version_ = version;
}
error_occurred_ = false;
auto network = Board::GetInstance().GetNetwork();
websocket_ = network->CreateWebSocket(1);
if (websocket_ == nullptr) {
ESP_LOGE(TAG, "Failed to create websocket");
return false;
}
if (!token.empty()) {
// If token not has a space, add "Bearer " prefix
if (token.find(" ") == std::string::npos) {
token = "Bearer " + token;
}
websocket_->SetHeader("Authorization", token.c_str());
}
websocket_->SetHeader("Protocol-Version", std::to_string(version_).c_str());
websocket_->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str());
websocket_->SetHeader("Client-Id", Board::GetInstance().GetUuid().c_str());
websocket_->OnData([this](const char* data, size_t len, bool binary) {
if (binary) {
if (on_incoming_audio_ != nullptr) {
if (version_ == 2) {
BinaryProtocol2* bp2 = (BinaryProtocol2*)data;
bp2->version = ntohs(bp2->version);
bp2->type = ntohs(bp2->type);
bp2->timestamp = ntohl(bp2->timestamp);
bp2->payload_size = ntohl(bp2->payload_size);
auto payload = (uint8_t*)bp2->payload;
on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
.sample_rate = server_sample_rate_,
.frame_duration = server_frame_duration_,
.timestamp = bp2->timestamp,
.payload = std::vector<uint8_t>(payload, payload + bp2->payload_size)
}));
} else if (version_ == 3) {
BinaryProtocol3* bp3 = (BinaryProtocol3*)data;
bp3->type = bp3->type;
bp3->payload_size = ntohs(bp3->payload_size);
auto payload = (uint8_t*)bp3->payload;
on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
.sample_rate = server_sample_rate_,
.frame_duration = server_frame_duration_,
.timestamp = 0,
.payload = std::vector<uint8_t>(payload, payload + bp3->payload_size)
}));
} else {
on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
.sample_rate = server_sample_rate_,
.frame_duration = server_frame_duration_,
.timestamp = 0,
.payload = std::vector<uint8_t>((uint8_t*)data, (uint8_t*)data + len)
}));
}
}
} else {
// Parse JSON data
auto root = cJSON_Parse(data);
auto type = cJSON_GetObjectItem(root, "type");
if (cJSON_IsString(type)) {
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else {
if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
}
} else {
ESP_LOGE(TAG, "Missing message type, data: %s", data);
}
cJSON_Delete(root);
}
last_incoming_time_ = std::chrono::steady_clock::now();
});
websocket_->OnDisconnected([this]() {
ESP_LOGI(TAG, "Websocket disconnected");
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
});
ESP_LOGI(TAG, "Connecting to websocket server: %s with version: %d", url.c_str(), version_);
if (!websocket_->Connect(url.c_str())) {
ESP_LOGE(TAG, "Failed to connect to websocket server");
SetError(Lang::Strings::SERVER_NOT_CONNECTED);
return false;
}
// Send hello message to describe the client
auto message = GetHelloMessage();
if (!SendText(message)) {
return false;
}
// Wait for server hello
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
if (!(bits & WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT)) {
ESP_LOGE(TAG, "Failed to receive server hello");
SetError(Lang::Strings::SERVER_TIMEOUT);
return false;
}
if (on_audio_channel_opened_ != nullptr) {
on_audio_channel_opened_();
}
return true;
}
std::string WebsocketProtocol::GetHelloMessage() {
// keys: message type, version, audio_params (format, sample_rate, channels)
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "hello");
cJSON_AddNumberToObject(root, "version", version_);
cJSON* features = cJSON_CreateObject();
#if CONFIG_USE_SERVER_AEC
cJSON_AddBoolToObject(features, "aec", true);
#endif
cJSON_AddBoolToObject(features, "mcp", true);
cJSON_AddItemToObject(root, "features", features);
cJSON_AddStringToObject(root, "transport", "websocket");
cJSON* audio_params = cJSON_CreateObject();
cJSON_AddStringToObject(audio_params, "format", "opus");
cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
cJSON_AddNumberToObject(audio_params, "channels", 1);
cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
cJSON_AddItemToObject(root, "audio_params", audio_params);
auto json_str = cJSON_PrintUnformatted(root);
std::string message(json_str);
cJSON_free(json_str);
cJSON_Delete(root);
return message;
}
void WebsocketProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "websocket") != 0) {
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
return;
}
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (cJSON_IsString(session_id)) {
session_id_ = session_id->valuestring;
ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
}
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (cJSON_IsObject(audio_params)) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (cJSON_IsNumber(sample_rate)) {
server_sample_rate_ = sample_rate->valueint;
}
auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
if (cJSON_IsNumber(frame_duration)) {
server_frame_duration_ = frame_duration->valueint;
}
}
xEventGroupSetBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT);
}

View File

@@ -1,34 +1,34 @@
#ifndef _WEBSOCKET_PROTOCOL_H_
#define _WEBSOCKET_PROTOCOL_H_
#include "protocol.h"
#include <web_socket.h>
#include <freertos/FreeRTOS.h>
#include <freertos/event_groups.h>
#define WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT (1 << 0)
class WebsocketProtocol : public Protocol {
public:
WebsocketProtocol();
~WebsocketProtocol();
bool Start() override;
bool SendAudio(std::unique_ptr<AudioStreamPacket> packet) override;
bool OpenAudioChannel() override;
void CloseAudioChannel() override;
bool IsAudioChannelOpened() const override;
private:
EventGroupHandle_t event_group_handle_;
std::unique_ptr<WebSocket> websocket_;
int version_ = 1;
void ParseServerHello(const cJSON* root);
bool SendText(const std::string& text) override;
std::string GetHelloMessage();
};
#endif
#ifndef _WEBSOCKET_PROTOCOL_H_
#define _WEBSOCKET_PROTOCOL_H_
#include "protocol.h"
#include <web_socket.h>
#include <freertos/FreeRTOS.h>
#include <freertos/event_groups.h>
#define WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT (1 << 0)
class WebsocketProtocol : public Protocol {
public:
WebsocketProtocol();
~WebsocketProtocol();
bool Start() override;
bool SendAudio(std::unique_ptr<AudioStreamPacket> packet) override;
bool OpenAudioChannel() override;
void CloseAudioChannel() override;
bool IsAudioChannelOpened() const override;
private:
EventGroupHandle_t event_group_handle_;
std::unique_ptr<WebSocket> websocket_;
int version_ = 1;
void ParseServerHello(const cJSON* root);
bool SendText(const std::string& text) override;
std::string GetHelloMessage();
};
#endif