Files
xiaozhi-esp32/managed_components/78__esp-ml307/src/web_socket.cc
2025-09-13 23:40:38 +08:00

438 lines
13 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "web_socket.h"
#include "network_interface.h"
#include <esp_log.h>
#include <cstdlib>
#include <cstring>
#include <esp_pthread.h>
#define TAG "WebSocket"
static std::string base64_encode(const unsigned char *data, size_t len) {
const char *base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
std::string encoded;
unsigned char char_array_3[3];
unsigned char char_array_4[4];
size_t i = 0;
while (i < len) {
size_t chunk_size = std::min((size_t)3, len - i);
for (size_t j = 0; j < 3; j++) {
char_array_3[j] = (j < chunk_size) ? data[i + j] : 0;
}
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
char_array_4[3] = char_array_3[2] & 0x3f;
for (size_t j = 0; j < 4; j++) {
if (j <= chunk_size) {
encoded.push_back(base64_chars[char_array_4[j]]);
} else {
encoded.push_back('=');
}
}
i += chunk_size;
}
return encoded;
}
WebSocket::WebSocket(NetworkInterface* network, int connect_id) : network_(network), connect_id_(connect_id) {
handshake_event_group_ = xEventGroupCreate();
}
WebSocket::~WebSocket() {
if (connected_) {
tcp_->Disconnect();
}
if (handshake_event_group_) {
vEventGroupDelete(handshake_event_group_);
}
}
void WebSocket::SetHeader(const char* key, const char* value) {
headers_[key] = value;
}
void WebSocket::SetReceiveBufferSize(size_t size) {
receive_buffer_size_ = size;
}
bool WebSocket::IsConnected() const {
return connected_;
}
bool WebSocket::Connect(const char* uri) {
std::string uri_str(uri);
std::string protocol, host, port, path;
size_t pos = 0;
size_t next_pos = 0;
// 解析协议
next_pos = uri_str.find("://");
if (next_pos == std::string::npos) {
ESP_LOGE(TAG, "Invalid URI format");
return false;
}
protocol = uri_str.substr(0, next_pos);
pos = next_pos + 3;
// 解析主机
next_pos = uri_str.find(':', pos);
if (next_pos == std::string::npos) {
next_pos = uri_str.find('/', pos);
if (next_pos == std::string::npos) {
host = uri_str.substr(pos);
path = "/";
} else {
host = uri_str.substr(pos, next_pos - pos);
path = uri_str.substr(next_pos);
}
port = (protocol == "wss") ? "443" : "80";
} else {
host = uri_str.substr(pos, next_pos - pos);
pos = next_pos + 1;
// 解析端口
next_pos = uri_str.find('/', pos);
if (next_pos == std::string::npos) {
port = uri_str.substr(pos);
path = "/";
} else {
port = uri_str.substr(pos, next_pos - pos);
path = uri_str.substr(next_pos);
}
}
ESP_LOGD(TAG, "Connecting to %s://%s:%s%s", protocol.c_str(), host.c_str(), port.c_str(), path.c_str());
// 设置 WebSocket 特定的头部
SetHeader("Upgrade", "websocket");
SetHeader("Connection", "Upgrade");
SetHeader("Sec-WebSocket-Version", "13");
// 生成随机的 Sec-WebSocket-Key
char key[25];
for (int i = 0; i < 16; ++i) {
key[i] = rand() % 256;
}
std::string base64_key = base64_encode(reinterpret_cast<const unsigned char*>(key), 16);
SetHeader("Sec-WebSocket-Key", base64_key.c_str());
if (protocol == "wss" || protocol == "https") {
tcp_ = network_->CreateSsl(connect_id_);
} else {
tcp_ = network_->CreateTcp(connect_id_);
}
connected_ = false;
// 使用 tcp 建立连接
if (!tcp_->Connect(host, std::stoi(port))) {
ESP_LOGE(TAG, "Failed to connect to server");
return false;
}
// 发送 WebSocket 握手请求
std::string request = "GET " + path + " HTTP/1.1\r\n";
if (headers_.find("Host") == headers_.end()) {
request += "Host: " + host + "\r\n";
}
for (const auto& header : headers_) {
request += header.first + ": " + header.second + "\r\n";
}
request += "\r\n";
if (tcp_->Send(request) < 0) {
ESP_LOGE(TAG, "Failed to send WebSocket handshake request");
return false;
}
// 清除事件位
xEventGroupClearBits(handshake_event_group_, HANDSHAKE_SUCCESS_BIT | HANDSHAKE_FAILED_BIT);
// 设置数据接收回调来处理握手和后续的WebSocket帧
tcp_->OnStream([this](const std::string& data) {
this->OnTcpData(data);
});
// 设置断开连接回调
tcp_->OnDisconnected([this]() {
if (connected_) {
connected_ = false;
if (on_disconnected_) {
on_disconnected_();
}
}
});
// 等待握手完成超时时间10秒
EventBits_t bits = xEventGroupWaitBits(
handshake_event_group_,
HANDSHAKE_SUCCESS_BIT | HANDSHAKE_FAILED_BIT,
pdFALSE, // 不清除事件位
pdFALSE, // 等待任意一个事件位
pdMS_TO_TICKS(10000) // 10秒超时
);
if (bits & HANDSHAKE_SUCCESS_BIT) {
connected_ = true;
if (on_connected_) {
on_connected_();
}
return true;
} else if (bits & HANDSHAKE_FAILED_BIT) {
ESP_LOGE(TAG, "WebSocket handshake failed");
if (on_error_) {
on_error_(-1);
}
return false;
} else {
ESP_LOGE(TAG, "WebSocket handshake timeout");
return false;
}
}
bool WebSocket::Send(const std::string& data) {
return Send(data.data(), data.size(), false);
}
bool WebSocket::Send(const void* data, size_t len, bool binary, bool fin) {
if (len > 65535) {
ESP_LOGE(TAG, "Data too large, maximum supported size is 65535 bytes");
return false;
}
std::string frame;
frame.reserve(len + 8); // 最大可能的帧大小2字节帧头 + 2字节长度 + 4字节mask
// 第一个字节FIN 位 + 操作码
uint8_t first_byte = (fin ? 0x80 : 0x00);
if (binary) {
first_byte |= 0x02; // 二进制帧
} else if (!continuation_) {
first_byte |= 0x01; // 文本帧
} // 否则操作码为0延续帧
frame.push_back(static_cast<char>(first_byte));
// 第二个字节MASK 位 + 有效载荷长度
if (len < 126) {
frame.push_back(static_cast<char>(0x80 | len)); // 设置MASK位
} else {
frame.push_back(static_cast<char>(0x80 | 126)); // 设置MASK位
frame.push_back(static_cast<char>((len >> 8) & 0xFF));
frame.push_back(static_cast<char>(len & 0xFF));
}
// 生成随机的4字节mask
uint8_t mask[4];
for (int i = 0; i < 4; ++i) {
mask[i] = rand() & 0xFF;
}
frame.append(reinterpret_cast<const char*>(mask), 4);
// 添加并mask处理有效载荷
const uint8_t* payload = static_cast<const uint8_t*>(data);
for (size_t i = 0; i < len; ++i) {
frame.push_back(static_cast<char>(payload[i] ^ mask[i % 4]));
}
// 更新continuation_状态
continuation_ = !fin;
// 发送帧
std::lock_guard<std::mutex> lock(send_mutex_);
return tcp_->Send(frame) >= 0;
}
void WebSocket::Ping() {
SendControlFrame(0x9, nullptr, 0);
}
void WebSocket::Close() {
if (connected_) {
SendControlFrame(0x8, nullptr, 0);
}
}
void WebSocket::OnConnected(std::function<void()> callback) {
on_connected_ = callback;
}
void WebSocket::OnDisconnected(std::function<void()> callback) {
on_disconnected_ = callback;
}
void WebSocket::OnData(std::function<void(const char*, size_t, bool binary)> callback) {
on_data_ = callback;
}
void WebSocket::OnError(std::function<void(int)> callback) {
on_error_ = callback;
}
void WebSocket::OnTcpData(const std::string& data) {
// 将新数据追加到接收缓冲区
receive_buffer_.append(data);
if (!handshake_completed_) {
// 检查握手响应
size_t pos = receive_buffer_.find("\r\n\r\n");
if (pos != std::string::npos) {
std::string handshake_response = receive_buffer_.substr(0, pos + 4);
receive_buffer_ = receive_buffer_.substr(pos + 4);
if (handshake_response.find("HTTP/1.1 101") != std::string::npos) {
handshake_completed_ = true;
// 设置握手成功事件
xEventGroupSetBits(handshake_event_group_, HANDSHAKE_SUCCESS_BIT);
} else {
ESP_LOGE(TAG, "WebSocket handshake failed");
// 设置握手失败事件
xEventGroupSetBits(handshake_event_group_, HANDSHAKE_FAILED_BIT);
return;
}
} else {
// 握手响应未完整接收
return;
}
}
// 处理WebSocket帧
static std::vector<char> current_message;
static bool is_fragmented = false;
static bool is_binary = false;
size_t buffer_offset = 0;
const char* buffer = receive_buffer_.c_str();
size_t buffer_size = receive_buffer_.size();
while (buffer_offset < buffer_size) {
if (buffer_size - buffer_offset < 2) break; // 需要更多数据
uint8_t opcode = buffer[buffer_offset] & 0x0F;
bool fin = (buffer[buffer_offset] & 0x80) != 0;
uint8_t mask = buffer[buffer_offset + 1] & 0x80;
uint64_t payload_length = buffer[buffer_offset + 1] & 0x7F;
size_t header_length = 2;
if (payload_length == 126) {
if (buffer_size - buffer_offset < 4) break; // 需要更多数据
payload_length = (buffer[buffer_offset + 2] << 8) | buffer[buffer_offset + 3];
header_length += 2;
} else if (payload_length == 127) {
if (buffer_size - buffer_offset < 10) break; // 需要更多数据
payload_length = 0;
for (int i = 0; i < 8; ++i) {
payload_length = (payload_length << 8) | buffer[buffer_offset + 2 + i];
}
header_length += 8;
}
uint8_t mask_key[4] = {0};
if (mask) {
if (buffer_size - buffer_offset < header_length + 4) break; // 需要更多数据
memcpy(mask_key, buffer + buffer_offset + header_length, 4);
header_length += 4;
}
if (buffer_size - buffer_offset < header_length + payload_length) break; // 需要更多数据
// 解码有效载荷
std::vector<char> payload(payload_length);
memcpy(payload.data(), buffer + buffer_offset + header_length, payload_length);
if (mask) {
for (size_t i = 0; i < payload_length; ++i) {
payload[i] ^= mask_key[i % 4];
}
}
// 处理帧
switch (opcode) {
case 0x0: // 延续帧
case 0x1: // 文本帧
case 0x2: // 二进制帧
if (opcode != 0x0 && is_fragmented) {
ESP_LOGE(TAG, "Received new message frame while still fragmenting");
break;
}
if (opcode != 0x0) {
is_fragmented = !fin;
is_binary = (opcode == 0x2);
current_message.clear();
}
current_message.insert(current_message.end(), payload.begin(), payload.end());
if (fin) {
if (on_data_) {
on_data_(current_message.data(), current_message.size(), is_binary);
}
current_message.clear();
is_fragmented = false;
}
break;
case 0x8: // 关闭帧
connected_ = false;
if (on_disconnected_) {
on_disconnected_();
}
break;
case 0x9: // Ping
std::thread([this, payload, payload_length]() {
SendControlFrame(0xA, payload.data(), payload_length);
}).detach();
break;
case 0xA: // Pong
break;
default:
ESP_LOGE(TAG, "Unknown opcode: %d", opcode);
break;
}
buffer_offset += header_length + payload_length;
}
// 保留未处理的数据
if (buffer_offset > 0) {
receive_buffer_ = receive_buffer_.substr(buffer_offset);
}
}
bool WebSocket::SendControlFrame(uint8_t opcode, const void* data, size_t len) {
if (len > 125) {
ESP_LOGE(TAG, "控制帧有效载荷过大");
return false;
}
std::string frame;
frame.reserve(len + 6); // 帧头 + 掩码 + 有效载荷
// 第一个字节FIN 位 + 操作码
frame.push_back(static_cast<char>(0x80 | opcode));
// 第二个字节MASK 位 + 有效载荷长度
frame.push_back(static_cast<char>(0x80 | len));
// 生成随机的4字节掩码
uint8_t mask[4];
for (int i = 0; i < 4; ++i) {
mask[i] = rand() & 0xFF;
}
frame.append(reinterpret_cast<const char*>(mask), 4);
// 添加并掩码处理有效载荷
const uint8_t* payload = static_cast<const uint8_t*>(data);
for (size_t i = 0; i < len; ++i) {
frame.push_back(static_cast<char>(payload[i] ^ mask[i % 4]));
}
// 发送帧
std::lock_guard<std::mutex> lock(send_mutex_);
return tcp_->Send(frame) >= 0;
}