118 lines
4.1 KiB
C++
118 lines
4.1 KiB
C++
/*
 *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
|
|
#ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
|
|
#define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
|
|
|
|
#include <deque>
#include <memory>
#include <set>
#include <utility>
#include <vector>
|
|
|
|
#include "api/array_view.h"
|
|
#include "net/dcsctp/public/dcsctp_socket.h"
|
|
|
|
namespace dcsctp {
|
|
namespace dcsctp_fuzzers {
|
|
|
|
// A fake timeout used during fuzzing.
|
|
class FuzzerTimeout : public Timeout {
|
|
public:
|
|
explicit FuzzerTimeout(std::set<TimeoutID>& active_timeouts)
|
|
: active_timeouts_(active_timeouts) {}
|
|
|
|
void Start(DurationMs duration_ms, TimeoutID timeout_id) override {
|
|
// Start is only allowed to be called on stopped or expired timeouts.
|
|
if (timeout_id_.has_value()) {
|
|
// It has been started before, but maybe it expired. Ensure that it's not
|
|
// running at least.
|
|
RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end());
|
|
}
|
|
timeout_id_ = timeout_id;
|
|
RTC_DCHECK(active_timeouts_.insert(timeout_id).second);
|
|
}
|
|
|
|
void Stop() override {
|
|
// Stop is only allowed to be called on active timeouts. Not stopped or
|
|
// expired.
|
|
RTC_DCHECK(timeout_id_.has_value());
|
|
RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1);
|
|
timeout_id_ = absl::nullopt;
|
|
}
|
|
|
|
// A set of all active timeouts, managed by `FuzzerCallbacks`.
|
|
std::set<TimeoutID>& active_timeouts_;
|
|
// If present, the timout is active and will expire reported as `timeout_id`.
|
|
absl::optional<TimeoutID> timeout_id_;
|
|
};
|
|
|
|
class FuzzerCallbacks : public DcSctpSocketCallbacks {
|
|
public:
|
|
static constexpr int kRandomValue = 42;
|
|
void SendPacket(rtc::ArrayView<const uint8_t> data) override {
|
|
sent_packets_.emplace_back(std::vector<uint8_t>(data.begin(), data.end()));
|
|
}
|
|
std::unique_ptr<Timeout> CreateTimeout() override {
|
|
return std::make_unique<FuzzerTimeout>(active_timeouts_);
|
|
}
|
|
TimeMs TimeMillis() override { return TimeMs(42); }
|
|
uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
|
|
return kRandomValue;
|
|
}
|
|
void OnMessageReceived(DcSctpMessage message) override {}
|
|
void OnError(ErrorKind error, absl::string_view message) override {}
|
|
void OnAborted(ErrorKind error, absl::string_view message) override {}
|
|
void OnConnected() override {}
|
|
void OnClosed() override {}
|
|
void OnConnectionRestarted() override {}
|
|
void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
|
|
absl::string_view reason) override {}
|
|
void OnStreamsResetPerformed(
|
|
rtc::ArrayView<const StreamID> outgoing_streams) override {}
|
|
void OnIncomingStreamsReset(
|
|
rtc::ArrayView<const StreamID> incoming_streams) override {}
|
|
void NotifyOutgoingMessageBufferEmpty() override {}
|
|
|
|
std::vector<uint8_t> ConsumeSentPacket() {
|
|
if (sent_packets_.empty()) {
|
|
return {};
|
|
}
|
|
std::vector<uint8_t> ret = sent_packets_.front();
|
|
sent_packets_.pop_front();
|
|
return ret;
|
|
}
|
|
|
|
// Given an index among the active timeouts, will expire that one.
|
|
absl::optional<TimeoutID> ExpireTimeout(size_t index) {
|
|
if (index < active_timeouts_.size()) {
|
|
auto it = active_timeouts_.begin();
|
|
std::advance(it, index);
|
|
TimeoutID timeout_id = *it;
|
|
active_timeouts_.erase(it);
|
|
return timeout_id;
|
|
}
|
|
return absl::nullopt;
|
|
}
|
|
|
|
private:
|
|
// Needs to be ordered, to allow fuzzers to expire timers.
|
|
std::set<TimeoutID> active_timeouts_;
|
|
std::deque<std::vector<uint8_t>> sent_packets_;
|
|
};
|
|
|
|
// Given some fuzzing `data` will send packets to the socket as well as calling
|
|
// API methods.
|
|
void FuzzSocket(DcSctpSocketInterface& socket,
|
|
FuzzerCallbacks& cb,
|
|
rtc::ArrayView<const uint8_t> data);
|
|
|
|
} // namespace dcsctp_fuzzers
|
|
} // namespace dcsctp
|
|
#endif // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
|