Botan
1.11.15
|
00001 /* 00002 * TLS Handshake IO 00003 * (C) 2012,2014 Jack Lloyd 00004 * 00005 * Botan is released under the Simplified BSD License (see license.txt) 00006 */ 00007 00008 #include <botan/internal/tls_handshake_io.h> 00009 #include <botan/internal/tls_messages.h> 00010 #include <botan/internal/tls_record.h> 00011 #include <botan/internal/tls_seq_numbers.h> 00012 #include <botan/exceptn.h> 00013 #include <chrono> 00014 00015 namespace Botan { 00016 00017 namespace TLS { 00018 00019 namespace { 00020 00021 inline size_t load_be24(const byte q[3]) 00022 { 00023 return make_u32bit(0, 00024 q[0], 00025 q[1], 00026 q[2]); 00027 } 00028 00029 void store_be24(byte out[3], size_t val) 00030 { 00031 out[0] = get_byte<u32bit>(1, val); 00032 out[1] = get_byte<u32bit>(2, val); 00033 out[2] = get_byte<u32bit>(3, val); 00034 } 00035 00036 } 00037 00038 Protocol_Version Stream_Handshake_IO::initial_record_version() const 00039 { 00040 return Protocol_Version::TLS_V10; 00041 } 00042 00043 void Stream_Handshake_IO::add_record(const std::vector<byte>& record, 00044 Record_Type record_type, u64bit) 00045 { 00046 if(record_type == HANDSHAKE) 00047 { 00048 m_queue.insert(m_queue.end(), record.begin(), record.end()); 00049 } 00050 else if(record_type == CHANGE_CIPHER_SPEC) 00051 { 00052 if(record.size() != 1 || record[0] != 1) 00053 throw Decoding_Error("Invalid ChangeCipherSpec"); 00054 00055 // Pretend it's a regular handshake message of zero length 00056 const byte ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 }; 00057 m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs)); 00058 } 00059 else 00060 throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing"); 00061 } 00062 00063 std::pair<Handshake_Type, std::vector<byte>> 00064 Stream_Handshake_IO::get_next_record(bool) 00065 { 00066 if(m_queue.size() >= 4) 00067 { 00068 const size_t length = make_u32bit(0, m_queue[1], m_queue[2], m_queue[3]); 00069 00070 if(m_queue.size() >= length + 4) 00071 { 00072 Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]); 00073 00074 std::vector<byte> contents(m_queue.begin() + 4, 00075 m_queue.begin() + 4 + length); 00076 00077 m_queue.erase(m_queue.begin(), m_queue.begin() + 4 + length); 00078 00079 return std::make_pair(type, contents); 00080 } 00081 } 00082 00083 return std::make_pair(HANDSHAKE_NONE, std::vector<byte>()); 00084 } 00085 00086 std::vector<byte> 00087 Stream_Handshake_IO::format(const std::vector<byte>& msg, 00088 Handshake_Type type) const 00089 { 00090 std::vector<byte> send_buf(4 + msg.size()); 00091 00092 const size_t buf_size = msg.size(); 00093 00094 send_buf[0] = type; 00095 00096 store_be24(&send_buf[1], buf_size); 00097 00098 copy_mem(&send_buf[4], &msg[0], msg.size()); 00099 00100 return send_buf; 00101 } 00102 00103 std::vector<byte> Stream_Handshake_IO::send(const Handshake_Message& msg) 00104 { 00105 const std::vector<byte> msg_bits = msg.serialize(); 00106 00107 if(msg.type() == HANDSHAKE_CCS) 00108 { 00109 m_send_hs(CHANGE_CIPHER_SPEC, msg_bits); 00110 return std::vector<byte>(); // not included in handshake hashes 00111 } 00112 00113 const std::vector<byte> buf = format(msg_bits, msg.type()); 00114 m_send_hs(HANDSHAKE, buf); 00115 return buf; 00116 } 00117 00118 Protocol_Version Datagram_Handshake_IO::initial_record_version() const 00119 { 00120 return Protocol_Version::DTLS_V10; 00121 } 00122 00123 namespace { 00124 00125 // 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1 00126 const u64bit INITIAL_TIMEOUT = 1*1000; 00127 const u64bit MAXIMUM_TIMEOUT = 60*1000; 00128 00129 u64bit steady_clock_ms() 00130 { 00131 return std::chrono::duration_cast<std::chrono::milliseconds>( 00132 std::chrono::steady_clock::now().time_since_epoch()).count(); 00133 } 00134 00135 } 00136 00137 bool Datagram_Handshake_IO::timeout_check() 00138 { 00139 if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) 00140 { 00141 /* 00142 If we haven't written anything yet obviously no timeout. 00143 Also no timeout possible if we are mid-flight, 00144 */ 00145 return false; 00146 } 00147 00148 const u64bit ms_since_write = steady_clock_ms() - m_last_write; 00149 00150 if(ms_since_write < m_next_timeout) 00151 return false; 00152 00153 std::vector<u16bit> flight; 00154 if(m_flights.size() == 1) 00155 flight = m_flights.at(0); // lost initial client hello 00156 else 00157 flight = m_flights.at(m_flights.size() - 2); 00158 00159 BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit"); 00160 00161 u16bit epoch = m_flight_data[flight[0]].epoch; 00162 00163 for(auto msg_seq : flight) 00164 { 00165 auto& msg = m_flight_data[msg_seq]; 00166 00167 if(msg.epoch != epoch) 00168 { 00169 // Epoch gap: insert the CCS 00170 std::vector<byte> ccs(1, 1); 00171 m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs); 00172 } 00173 00174 send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits); 00175 epoch = msg.epoch; 00176 } 00177 00178 m_next_timeout = std::min(2 * m_next_timeout, MAXIMUM_TIMEOUT); 00179 return true; 00180 } 00181 00182 void Datagram_Handshake_IO::add_record(const std::vector<byte>& record, 00183 Record_Type record_type, 00184 u64bit record_sequence) 00185 { 00186 const u16bit epoch = static_cast<u16bit>(record_sequence >> 48); 00187 00188 if(record_type == CHANGE_CIPHER_SPEC) 00189 { 00190 // TODO: check this is otherwise empty 00191 m_ccs_epochs.insert(epoch); 00192 return; 00193 } 00194 00195 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12; 00196 00197 const byte* record_bits = &record[0]; 00198 size_t record_size = record.size(); 00199 00200 while(record_size) 00201 { 00202 if(record_size < DTLS_HANDSHAKE_HEADER_LEN) 00203 return; // completely bogus? at least degenerate/weird 00204 00205 const byte msg_type = record_bits[0]; 00206 const size_t msg_len = load_be24(&record_bits[1]); 00207 const u16bit message_seq = load_be<u16bit>(&record_bits[4], 0); 00208 const size_t fragment_offset = load_be24(&record_bits[6]); 00209 const size_t fragment_length = load_be24(&record_bits[9]); 00210 00211 const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length; 00212 00213 if(record_size < total_size) 00214 throw Decoding_Error("Bad lengths in DTLS header"); 00215 00216 if(message_seq >= m_in_message_seq) 00217 { 00218 m_messages[message_seq].add_fragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN], 00219 fragment_length, 00220 fragment_offset, 00221 epoch, 00222 msg_type, 00223 msg_len); 00224 } 00225 else 00226 { 00227 // TODO: detect retransmitted flight 00228 } 00229 00230 record_bits += total_size; 00231 record_size -= total_size; 00232 } 00233 } 00234 00235 std::pair<Handshake_Type, std::vector<byte>> 00236 Datagram_Handshake_IO::get_next_record(bool expecting_ccs) 00237 { 00238 // Expecting a message means the last flight is concluded 00239 if(!m_flights.rbegin()->empty()) 00240 m_flights.push_back(std::vector<u16bit>()); 00241 00242 if(expecting_ccs) 00243 { 00244 if(!m_messages.empty()) 00245 { 00246 const u16bit current_epoch = m_messages.begin()->second.epoch(); 00247 00248 if(m_ccs_epochs.count(current_epoch)) 00249 return std::make_pair(HANDSHAKE_CCS, std::vector<byte>()); 00250 } 00251 00252 return std::make_pair(HANDSHAKE_NONE, std::vector<byte>()); 00253 } 00254 00255 auto i = m_messages.find(m_in_message_seq); 00256 00257 if(i == m_messages.end() || !i->second.complete()) 00258 return std::make_pair(HANDSHAKE_NONE, std::vector<byte>()); 00259 00260 m_in_message_seq += 1; 00261 00262 return i->second.message(); 00263 } 00264 00265 void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment( 00266 const byte fragment[], 00267 size_t fragment_length, 00268 size_t fragment_offset, 00269 u16bit epoch, 00270 byte msg_type, 00271 size_t msg_length) 00272 { 00273 if(complete()) 00274 return; // already have entire message, ignore this 00275 00276 if(m_msg_type == HANDSHAKE_NONE) 00277 { 00278 m_epoch = epoch; 00279 m_msg_type = msg_type; 00280 m_msg_length = msg_length; 00281 } 00282 00283 if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch) 00284 throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header"); 00285 00286 if(fragment_offset > m_msg_length) 00287 throw Decoding_Error("Fragment offset past end of message"); 00288 00289 if(fragment_offset + fragment_length > m_msg_length) 00290 throw Decoding_Error("Fragment overlaps past end of message"); 00291 00292 if(fragment_offset == 0 && fragment_length == m_msg_length) 00293 { 00294 m_fragments.clear(); 00295 m_message.assign(fragment, fragment+fragment_length); 00296 } 00297 else 00298 { 00299 /* 00300 * FIXME. This is a pretty lame way to do defragmentation, huge 00301 * overhead with a tree node per byte. 00302 * 00303 * Also should confirm that all overlaps have no changes, 00304 * otherwise we expose ourselves to the classic fingerprinting 00305 * and IDS evasion attacks on IP fragmentation. 00306 */ 00307 for(size_t i = 0; i != fragment_length; ++i) 00308 m_fragments[fragment_offset+i] = fragment[i]; 00309 00310 if(m_fragments.size() == m_msg_length) 00311 { 00312 m_message.resize(m_msg_length); 00313 for(size_t i = 0; i != m_msg_length; ++i) 00314 m_message[i] = m_fragments[i]; 00315 m_fragments.clear(); 00316 } 00317 } 00318 } 00319 00320 bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const 00321 { 00322 return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length); 00323 } 00324 00325 std::pair<Handshake_Type, std::vector<byte>> 00326 Datagram_Handshake_IO::Handshake_Reassembly::message() const 00327 { 00328 if(!complete()) 00329 throw Internal_Error("Datagram_Handshake_IO - message not complete"); 00330 00331 return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message); 00332 } 00333 00334 std::vector<byte> 00335 Datagram_Handshake_IO::format_fragment(const byte fragment[], 00336 size_t frag_len, 00337 u16bit frag_offset, 00338 u16bit msg_len, 00339 Handshake_Type type, 00340 u16bit msg_sequence) const 00341 { 00342 std::vector<byte> send_buf(12 + frag_len); 00343 00344 send_buf[0] = type; 00345 00346 store_be24(&send_buf[1], msg_len); 00347 00348 store_be(msg_sequence, &send_buf[4]); 00349 00350 store_be24(&send_buf[6], frag_offset); 00351 store_be24(&send_buf[9], frag_len); 00352 00353 copy_mem(&send_buf[12], &fragment[0], frag_len); 00354 00355 return send_buf; 00356 } 00357 00358 std::vector<byte> 00359 Datagram_Handshake_IO::format_w_seq(const std::vector<byte>& msg, 00360 Handshake_Type type, 00361 u16bit msg_sequence) const 00362 { 00363 return format_fragment(&msg[0], msg.size(), 0, msg.size(), type, msg_sequence); 00364 } 00365 00366 std::vector<byte> 00367 Datagram_Handshake_IO::format(const std::vector<byte>& msg, 00368 Handshake_Type type) const 00369 { 00370 return format_w_seq(msg, type, m_in_message_seq - 1); 00371 } 00372 00373 namespace { 00374 00375 size_t split_for_mtu(size_t mtu, size_t msg_size) 00376 { 00377 const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers 00378 00379 const size_t parts = (msg_size + mtu) / mtu; 00380 00381 if(parts + DTLS_HEADERS_SIZE > mtu) 00382 return parts + 1; 00383 00384 return parts; 00385 } 00386 00387 } 00388 00389 std::vector<byte> 00390 Datagram_Handshake_IO::send(const Handshake_Message& msg) 00391 { 00392 const std::vector<byte> msg_bits = msg.serialize(); 00393 const u16bit epoch = m_seqs.current_write_epoch(); 00394 const Handshake_Type msg_type = msg.type(); 00395 00396 if(msg_type == HANDSHAKE_CCS) 00397 { 00398 m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits); 00399 return std::vector<byte>(); // not included in handshake hashes 00400 } 00401 00402 // Note: not saving CCS, instead we know it was there due to change in epoch 00403 m_flights.rbegin()->push_back(m_out_message_seq); 00404 m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits); 00405 00406 m_out_message_seq += 1; 00407 m_last_write = steady_clock_ms(); 00408 m_next_timeout = INITIAL_TIMEOUT; 00409 00410 return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits); 00411 } 00412 00413 std::vector<byte> Datagram_Handshake_IO::send_message(u16bit msg_seq, 00414 u16bit epoch, 00415 Handshake_Type msg_type, 00416 const std::vector<byte>& msg_bits) 00417 { 00418 const std::vector<byte> no_fragment = 00419 format_w_seq(msg_bits, msg_type, msg_seq); 00420 00421 if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) 00422 m_send_hs(epoch, HANDSHAKE, no_fragment); 00423 else 00424 { 00425 const size_t parts = split_for_mtu(m_mtu, msg_bits.size()); 00426 00427 const size_t parts_size = (msg_bits.size() + parts) / parts; 00428 00429 size_t frag_offset = 0; 00430 00431 while(frag_offset != msg_bits.size()) 00432 { 00433 const size_t frag_len = 00434 std::min<size_t>(msg_bits.size() - frag_offset, 00435 parts_size); 00436 00437 m_send_hs(epoch, 00438 HANDSHAKE, 00439 format_fragment(&msg_bits[frag_offset], 00440 frag_len, 00441 frag_offset, 00442 msg_bits.size(), 00443 msg_type, 00444 msg_seq)); 00445 00446 frag_offset += frag_len; 00447 } 00448 } 00449 00450 return no_fragment; 00451 } 00452 00453 } 00454 }