#include "rpc.h"
#include "rpccommands.h"
#include "log.h"

/* NOTE(review): the names of the four original <...> includes were lost in
 * this copy of the file. The set below is reconstructed from the APIs used in
 * this file (ChibiOS chIO* channel calls, nanopb pb_encode/pb_decode) --
 * confirm against the build. */
#include <ch.h>
#include <pb_decode.h>
#include <pb_encode.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/*
 * Packet framing used in both directions:
 *
 *   [len_hi][len_lo][command][payload ...][checksum]
 *
 * The 16-bit length counts command + payload + checksum, i.e. payload + 2.
 * The checksum is 0xFF XOR command XOR every payload byte; the two length
 * bytes are not included.
 */

// Timeout once the packet has started
const systime_t timeout = S2ST(15);

// Timeout between packets
const systime_t header_timeout = S2ST(120);

enum {
    // Bytes before the payload: two length bytes plus the command byte.
    RPC_HEADER_SIZE = 3,
    // Largest payload whose length (payload + command + checksum) still fits
    // in the 16-bit length field.
    RPC_MAX_PAYLOAD = 0xFFFF - 2,
};

typedef enum {
    READ_PACKET,     // Request payload and checksum not yet consumed
    WRITE_RESPONSE,  // Request consumed; a reply must be written
    DONE             // Reply written
} rpc_state_t;

struct _rpc_t {
    rpc_state_t state;

    // IO stream we are communicating with
    BaseChannel *io;

    // Data about the command being executed
    RPCCommand command;
    const pb_field_t *request_fields;
    const pb_field_t *response_fields;

    // Received packet length and checksum
    size_t request_payload_size;
    uint8_t request_checksum;

    // Response packet length and checksum
    const void *response_prepared_for;
    size_t response_payload_size;
    uint8_t response_checksum;

    // Should rpc_server exit after this call?
    bool abort;
};

/* nanopb input callback. Reads request payload bytes from rpc->io and folds
 * each one into rpc->request_checksum. A NULL buf means "skip count bytes";
 * skipped bytes are still checksummed. Returns false on a read timeout. */
static bool stream_read(pb_istream_t *stream, uint8_t *buf, size_t count)
{
    rpc_t *rpc = (rpc_t*)stream->state;

    if (buf == NULL) {
        for (size_t remaining = count; remaining > 0; remaining--) {
            int byte = chIOGetTimeout(rpc->io, timeout);
            if (byte < 0) {
                LOG("RPC timeout while skipping value, %u remaining",
                    (unsigned)remaining);
                return false;
            }
            rpc->request_checksum ^= (uint8_t)byte;
        }
        return true;
    }

    if (chIOReadTimeout(rpc->io, buf, count, timeout) != count) {
        LOG("RPC timeout while reading, %u total", (unsigned)count);
        return false;
    }

    for (size_t i = 0; i < count; i++)
        rpc->request_checksum ^= buf[i];

    return true;
}

/* nanopb output callback. Folds the bytes into rpc->response_checksum and
 * writes them to rpc->io. Returns false if the write did not complete. */
static bool stream_write(pb_ostream_t *stream, const uint8_t *buf, size_t count)
{
    rpc_t *rpc = (rpc_t*)stream->state;

    for (size_t i = 0; i < count; i++)
        rpc->response_checksum ^= buf[i];

    return chIOWriteTimeout(rpc->io, buf, count, timeout) == count;
}

/* Consumes the request payload and its checksum byte.
 *
 * If dest is non-NULL, the payload is decoded into it using
 * rpc->request_fields. Otherwise the payload is just read and checksummed.
 * Leftover bytes the decoder did not consume are drained, so the stream stays
 * framed. On success the state moves to WRITE_RESPONSE.
 *
 * Returns false on:
 *  - wrong state, read timeout or checksum mismatch: the session is aborted;
 *  - decode failure: a protocol error reply is sent. */
bool rpc_receive(rpc_t *rpc, void *dest)
{
    if (rpc->state != READ_PACKET) {
        LOG("rpc_receive called in wrong state %d", (int)rpc->state);
        rpc_abort(rpc);
        return false;
    }

    pb_istream_t stream = {
        .callback = &stream_read,
        .state = rpc,
        .bytes_left = rpc->request_payload_size,
    };

    bool status = true;
    if (dest != NULL)
        status = pb_decode(&stream, rpc->request_fields, dest);

    // Drain whatever was not decoded so the checksum covers the whole payload.
    if (stream.bytes_left != 0 && !pb_read(&stream, NULL, stream.bytes_left)) {
        // Read timeout, abort
        rpc_abort(rpc);
        return false;
    }

    int checksum = chIOGetTimeout(rpc->io, timeout);
    if (checksum < 0) {
        LOG("RPC timeout while reading request checksum");
        rpc_abort(rpc);
        return false;
    }
    if (checksum != rpc->request_checksum) {
        LOG("RPC request checksum error");
        rpc_abort(rpc);
        return false;
    }

    rpc->state = WRITE_RESPONSE;

    if (!status) {
        LOG("pb_decode failed, messagetype %d", (int)rpc->command);
        rpc_protocol_error(rpc, "pb_decode failed");
        return false;
    }

    return true;
}

/* Computes the encoded size of src as a response and remembers it.
 *
 * A later rpc_respond() with the same src pointer then skips re-sizing.
 * Returns false if encoding fails or the payload exceeds RPC_MAX_PAYLOAD. */
bool rpc_prepare(rpc_t *rpc, const void *src)
{
    pb_ostream_t sizestream = {0};
    if (!pb_encode(&sizestream, rpc->response_fields, src))
        return false;

    if (sizestream.bytes_written > RPC_MAX_PAYLOAD) {
        LOG("Packet exceeds maximum size, %u",
            (unsigned)sizestream.bytes_written);
        return false;
    }

    rpc->response_prepared_for = src;
    rpc->response_payload_size = sizestream.bytes_written;
    return true;
}

/* Writes the 3-byte packet header for a payload of payload_size bytes.
 *
 * Also seeds rpc->response_checksum with 0xFF ^ command. The caller must
 * ensure payload_size <= RPC_MAX_PAYLOAD. Returns false on write timeout. */
static bool write_header(rpc_t *rpc, uint8_t command, size_t payload_size)
{
    size_t length = payload_size + 2;
    uint8_t header[RPC_HEADER_SIZE] = {
        (uint8_t)(length >> 8), (uint8_t)(length & 0xFF), command
    };

    if (chIOWriteTimeout(rpc->io, header, sizeof header, timeout) != sizeof header)
        return false;

    rpc->response_checksum = 0xFF ^ command;
    return true;
}

/* Encodes src (skipped when src is NULL) and appends the checksum byte.
 *
 * Exactly payload_size bytes must come out, because write_header() has
 * already announced that length to the peer. Logs and returns false on any
 * failure; the stream is then out of sync and the caller should abort. */
static bool write_payload(rpc_t *rpc, const pb_field_t *fields,
                          const void *src, size_t payload_size)
{
    pb_ostream_t stream = {
        .callback = &stream_write,
        .state = rpc,
        .max_size = payload_size,
        .bytes_written = 0,
    };

    if (src != NULL && !pb_encode(&stream, fields, src)) {
        LOG("RPC pb_encode failed midway");
        return false;
    }

    if (stream.bytes_written != stream.max_size) {
        LOG("pb_encode size changed");
        return false;
    }

    if (chIOPutTimeout(rpc->io, rpc->response_checksum, timeout) != Q_OK) {
        LOG("RPC checksum write timeouted");
        return false;
    }

    return true;
}

/* Sends the reply for the current command, encoded with rpc->response_fields.
 *
 * If the request has not been consumed yet, it is consumed first (without
 * decoding). src may be NULL to send an empty payload. On success the state
 * becomes DONE. Any failure after the request is consumed aborts the
 * session. */
bool rpc_respond(rpc_t *rpc, const void *src)
{
    if (rpc->state == READ_PACKET && !rpc_receive(rpc, NULL))
        return false;

    if (rpc->state != WRITE_RESPONSE) {
        LOG("rpc_respond called with wrong state %d", (int)rpc->state);
        rpc_abort(rpc);
        return false;
    }

    if (src == NULL) {
        rpc->response_payload_size = 0;
    } else if (rpc->response_prepared_for != src && !rpc_prepare(rpc, src)) {
        LOG("RPC late prepare failed, aborting");
        rpc_abort(rpc);
        return false;
    }

    if (!write_header(rpc, (uint8_t)rpc->command, rpc->response_payload_size)) {
        LOG("RPC response header write timeouted, aborting");
        rpc_abort(rpc);
        return false;
    }

    if (!write_payload(rpc, rpc->response_fields, src, rpc->response_payload_size)) {
        rpc_abort(rpc);
        return false;
    }

    rpc->state = DONE;
    return true;
}

/* Sizes, frames and sends an arbitrary message with the given command byte.
 *
 * Enforces the same size limit and exact-length check as rpc_respond(). On
 * success the state becomes DONE. Returns false on failure; the caller
 * decides whether to abort. */
static bool rpc_sendpacket(rpc_t *rpc, RPCCommand msgtype,
                           const pb_field_t *fields, const void *src)
{
    pb_ostream_t sizestream = {0};
    if (!pb_encode(&sizestream, fields, src))
        return false;

    if (sizestream.bytes_written > RPC_MAX_PAYLOAD) {
        LOG("Packet exceeds maximum size, %u",
            (unsigned)sizestream.bytes_written);
        return false;
    }

    if (!write_header(rpc, (uint8_t)msgtype, sizestream.bytes_written))
        return false;

    if (!write_payload(rpc, fields, src, sizestream.bytes_written))
        return false;

    rpc->state = DONE;
    return true;
}

/* Replies to the current command with an RPCCommand_Error packet.
 *
 * If the request has not been consumed yet, it is consumed first. Aborts the
 * session if the state is wrong or sending fails. */
void rpc_error(rpc_t *rpc, const ErrorResponse *error)
{
    if (rpc->state == READ_PACKET && !rpc_receive(rpc, NULL))
        return;

    if (rpc->state != WRITE_RESPONSE) {
        LOG("rpc_error called with wrong state %d", (int)rpc->state);
        rpc_abort(rpc);
        return;
    }

    if (!rpc_sendpacket(rpc, RPCCommand_Error, ErrorResponse_fields, error)) {
        LOG("rpc_error failed");
        rpc_abort(rpc);
    }
}

/* nanopb encode callback that writes arg as a NUL-terminated string field.
 * A NULL arg omits the field instead of passing NULL to pb_enc_string. */
static bool encode_string(pb_ostream_t *stream, const pb_field_t *field, const void *arg)
{
    if (arg == NULL)
        return true;

    if (!pb_encode_tag_for_field(stream, field))
        return false;

    return pb_enc_string(stream, field, arg);
}

/* Attaches message as the string payload of error and sends it via
 * rpc_error(). The message pointer must stay valid until rpc_error()
 * returns, because it is encoded lazily (twice: sizing pass, then write). */
static void send_error_message(rpc_t *rpc, ErrorResponse *error, const char *message)
{
    error->message.funcs.encode = &encode_string;
    error->message.arg = (void*)message;
    rpc_error(rpc, error);
}

// Replies with an ErrorResponse of scope PROTOCOL carrying message.
void rpc_protocol_error(rpc_t *rpc, const char *message)
{
    ErrorResponse error = {0};
    error.scope = ErrorResponse_ErrorScope_PROTOCOL;
    send_error_message(rpc, &error, message);
}

// Replies with an ErrorResponse of scope RESOURCES carrying message.
void rpc_resources_error(rpc_t *rpc, const char *message)
{
    ErrorResponse error = {0};
    error.scope = ErrorResponse_ErrorScope_RESOURCES;
    send_error_message(rpc, &error, message);
}

// Marks the session for termination; rpc_server exits after this command.
void rpc_abort(rpc_t *rpc)
{
    LOG("rpc_abort called");
    rpc->abort = true;
}

/* Sends the fixed login packet and waits for an RPCCommand_Login reply.
 *
 * The reply body is skipped, not decoded, and its checksum is not verified.
 * Returns true only if the whole reply was received. */
bool rpc_login(BaseChannel *io)
{
    static const uint8_t loginpacket[] = {
        0x00, 0x04, RPCCommand_Login, 0x08, 0x01,
        0xFF ^ 0x08 ^ 0x01 ^ RPCCommand_Login
    };

    if (chIOWriteTimeout(io, loginpacket, sizeof loginpacket, timeout) != sizeof loginpacket)
        return false;

    uint8_t responseheader[RPC_HEADER_SIZE];
    if (chIOReadTimeout(io, responseheader, sizeof responseheader, timeout) != sizeof responseheader)
        return false;

    if (responseheader[2] != RPCCommand_Login) {
        LOG("RPC proxy gave error on login");
        return false;
    }

    // Just skip the response contents for now...
    // The command byte is already consumed: length - 1 bytes remain.
    int length = (responseheader[0] << 8) | responseheader[1];
    if (length < 1) {
        // Without this check length would go negative and the loop below
        // would swallow input until a timeout.
        LOG("RPC login response has invalid length %d", length);
        return false;
    }

    length -= 1;
    while (length > 0 && chIOGetTimeout(io, timeout) >= 0)
        length--;

    return length == 0;
}

/* Selects the request/response descriptors for rpc->command and calls the
 * matching rpc_<name>() handler generated from RPCCOMMANDS_XMACRO. Unknown
 * commands get a protocol error reply. */
static void rpc_dispatch(rpc_t *rpc)
{
    switch (rpc->command) {
#define X(name) \
    case RPCCommand_ ## name: \
        LOG("RPC: " #name); \
        rpc->request_fields = name ## Request_fields; \
        rpc->response_fields = name ## Response_fields; \
        rpc_ ## name(rpc); \
        break;

        RPCCOMMANDS_XMACRO

#undef X

    default:
        LOG("Unknown RPC command %d", (int)rpc->command);
        rpc_protocol_error(rpc, "Unknown command");
        break;
    }
}

/* Logs in on io, then serves one command at a time until a handler aborts,
 * a header read times out, or a malformed header is received. */
void rpc_server(BaseChannel *io)
{
    if (!rpc_login(io)) {
        LOG("RPC login failed");
        return;
    }

    rpc_t rpc;
    do {
        uint8_t header[RPC_HEADER_SIZE];
        if (chIOReadTimeout(io, header, sizeof header, header_timeout) != sizeof header) {
            LOG("RPC header read timeouted");
            return;
        }

        unsigned length = ((unsigned)header[0] << 8) | header[1];
        if (length < 2) {
            // Too short to hold command + checksum. "length - 2" would wrap
            // to a huge payload size, and framing can no longer be trusted.
            LOG("RPC invalid packet length %u", length);
            return;
        }

        rpc = (rpc_t){
            .state = READ_PACKET,
            .io = io,
            .command = (RPCCommand)header[2],
            .request_fields = NULL,
            .response_fields = NULL,
            .request_payload_size = length - 2,
            .request_checksum = (uint8_t)(0xFF ^ header[2]),
            .response_prepared_for = NULL,
            .response_payload_size = 0,
            .response_checksum = 0,
            .abort = false,
        };

        rpc_dispatch(&rpc);

        if (!rpc.abort && rpc.state != DONE) {
            // A handler returned without replying. Otherwise the unread
            // request would be parsed as the next header, or the client would
            // wait forever.
            LOG("RPC handler did not respond");
            rpc_protocol_error(&rpc, "No response");
        }
    } while (!rpc.abort);
}