From 90d26623c5fc4f0905e9702b57eb73805fbf3069 Mon Sep 17 00:00:00 2001 From: Rain Mark Date: Wed, 26 Jan 2022 02:17:56 -0500 Subject: [PATCH] add python asyncio-like event loop --- .gitignore | 2 +- build.sh | 12 +- context.cpp => coroutine.cpp | 12 +- context.h => coroutine.h | 5 + event_loop.h | 227 +++++++++++++++++++++++++++++++++++ future.h | 82 +++++++++++++ future_test.cpp | 25 ++++ loop_test.cpp | 47 ++++++++ main.cpp => switch_test.cpp | 2 +- tcp_client.py | 11 ++ 10 files changed, 420 insertions(+), 5 deletions(-) rename context.cpp => coroutine.cpp (59%) rename context.h => coroutine.h (89%) create mode 100644 event_loop.h create mode 100644 future.h create mode 100644 future_test.cpp create mode 100644 loop_test.cpp rename main.cpp => switch_test.cpp (96%) create mode 100644 tcp_client.py diff --git a/.gitignore b/.gitignore index e7790e6..35c6c76 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ # Prerequisites *.d -cops +bin/ # Compiled Object files *.slo diff --git a/build.sh b/build.sh index c2105ca..d817939 100755 --- a/build.sh +++ b/build.sh @@ -1,3 +1,13 @@ #!/bin/bash -g++ -g -std=c++17 -I. *.cpp *.S -o cops +CXX="/usr/local/gcc-10/bin/g++-10" +FLAGS="-g -std=c++17 -fno-sized-deallocation -D_GLIBCXX_USE_CXX11_ABI=0 -I." + +mkdir -p bin + +$CXX $FLAGS switch_test.cpp coroutine.cpp *.S -o bin/switch_test + +$CXX $FLAGS future_test.cpp -o bin/future_test + +$CXX $FLAGS -O0 loop_test.cpp coroutine.cpp *.S -o bin/loop_test + diff --git a/context.cpp b/coroutine.cpp similarity index 59% rename from context.cpp rename to coroutine.cpp index 176b694..06ff4c2 100644 --- a/context.cpp +++ b/coroutine.cpp @@ -1,5 +1,5 @@ -#include -#include +#include "coroutine.h" +#include "event_loop.h" namespace cops{ @@ -9,6 +9,7 @@ coro_t* current = &def; void main(coro_t* coro, context from) { coro->next_->ctx_ = from; coro->fn_(); + coro->fut_.set_value(0); coro->switch_out(); } @@ -25,4 +26,11 @@ void coro_t::switch_in() { ctx_ = switch_context(this, ctx_); } +void coro_t::detach(std::unique_ptr& self, event_loop_t* loop) { + // make sure destruction after coro execute over + fut_.set_callback([loop, self = self.release()]() { + loop->call_soon([self]() { delete self; }); + }); +} + } diff --git a/context.h b/coroutine.h similarity index 89% rename from context.h rename to coroutine.h index 2084e93..5cff059 100644 --- a/context.h +++ b/coroutine.h @@ -4,8 +4,11 @@ #include #include +#include + namespace cops { class coro_t; +class event_loop_t; } using context = void*; @@ -47,12 +50,14 @@ public: void switch_out(); void switch_in(); + void detach(std::unique_ptr& self, event_loop_t* loop); public: context ctx_; coro_t* next_; stack_t stack_; Fn fn_; + future_t fut_; }; extern coro_t* current; diff --git a/event_loop.h b/event_loop.h new file mode 100644 index 0000000..e8890cc --- /dev/null +++ b/event_loop.h @@ -0,0 +1,227 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "coroutine.h" + +namespace cops { + +class task_t { +public: + virtual ~task_t() = default; + virtual void run() = 0; +}; + +template +class lambda_task_t : public task_t { +public: + explicit lambda_task_t(T&& task) : task_(std::forward(task)) {} + void run() override { + task_(); + } +private: + T task_; +}; + +template +inline std::unique_ptr make_lambda_task(T&& lambda) { + return std::unique_ptr(new lambda_task_t(std::forward(lambda))); +} + +inline void oops(const std::string& s) { + perror(s.data()); + exit(1); +} + +class event_loop_t { +public: + struct epoll_data_t { + int fd; + std::unique_ptr task; + }; + static constexpr int kMaxEpollEvents = 128; + +public: + event_loop_t() = default; + ~event_loop_t() { + if (!closed_) { + stop(); + } + } + + template + void call_soon(Fn&& fn) { + task_queue_.push_back(make_lambda_task(std::forward(fn)).release()); + } + + template + void create_task(Fn&& fn) { + auto coro = make_coro(std::forward(fn)); + call_soon([coro = coro.get()]() { coro->switch_in(); }); + coro->detach(coro, this); + assert(coro.get() == nullptr); + } + + int create_server(const std::string& host, int16_t port) { + int sockfd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); + if (sockfd == 0) { + oops("socket() " + std::to_string(errno)); + } + int opt = 1; + if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt))) { + oops("setsockopt() " + std::to_string(errno)); + } + struct sockaddr_in sa; + sa.sin_family = AF_INET; + inet_pton(AF_INET, host.data(), &(sa.sin_addr)); + sa.sin_port = htons(port); + if (bind(sockfd, (struct sockaddr*)&sa, sizeof(sa)) < 0) { + oops("bind() " + std::to_string(errno)); + } + int backlog = 4096; + if (listen(sockfd, backlog) < 0) { + oops("listen() " + std::to_string(errno)); + } + return sockfd; + } + + int sock_accept(int sockfd) { + struct sockaddr_in sa; + socklen_t len = sizeof(sa); + int fd = accept4(sockfd, (struct sockaddr*)&sa, (socklen_t*)&len, SOCK_NONBLOCK); + // TODO: return std::tuple + if (fd >= 0) { + return fd; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + auto curr = current; + file_desc_add(sockfd, EPOLLIN, [curr]() { curr->switch_in(); }); + curr->switch_out(); + } + return accept4(sockfd, (struct sockaddr*)&sa, (socklen_t*)&len, SOCK_NONBLOCK); + } + + ssize_t sock_recv(int sockfd, void* buf, size_t len, int flags) { + ssize_t n = recv(sockfd, buf, len, flags); + if (n >= 0) { + return n; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + auto curr = current; + file_desc_add(sockfd, EPOLLIN, [curr]() { curr->switch_in(); }); + curr->switch_out(); + } + return recv(sockfd, buf, len, flags); + } + + ssize_t sock_send(int sockfd, void* buf, size_t len, int flags) { + ssize_t n = send(sockfd, buf, len, flags); + if (n >= 0) { + return n; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + auto curr = current; + file_desc_add(sockfd, EPOLLOUT, [curr]() { curr->switch_in(); }); + curr->switch_out(); + } + return send(sockfd, buf, len, flags); + } + +public: + void run_forever() { + epollfd_ = epoll_create(kMaxEpollEvents); + if (epollfd_ < 0) { + oops("epoll_create() " + std::to_string(errno)); + } + eventfd_ = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + if (eventfd_ < 0) { + oops("eventfd() " + std::to_string(errno)); + } + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.ptr = nullptr; + if (epoll_ctl(epollfd_, EPOLL_CTL_ADD, eventfd_, &ev) < 0) { + oops("epoll_ctl() " + std::to_string(errno)); + } + + closed_ = false; + while (!closed_) { + int timeout = task_queue_.empty() ? -1 : 0; + struct epoll_event events[kMaxEpollEvents]; + int n = epoll_wait(epollfd_, events, kMaxEpollEvents, timeout); + if (n < 0) { + oops("epoll_wait() " + std::to_string(errno)); + } + for (int i = 0; i < n; ++i) { + auto data = static_cast(events[i].data.ptr); + // response wakeup + if (!data) { + static char _unused[8]; + read(eventfd_, _unused, 8); + continue; + } + // invoke io callback + file_desc_del(data->fd); + data->task->run(); + delete data; + } + // schedule normal task + if (!task_queue_.empty()) { + std::unique_ptr task(task_queue_.front()); + task_queue_.pop_front(); + task->run(); + } + } + file_desc_del(eventfd_); + close(eventfd_); + close(epollfd_); + eventfd_ = -1; + epollfd_ = -1; + } + void stop() { + closed_ = true; + } + +private: + void wakeup() { + if (eventfd_ > 0) { + static uint64_t one = 1; + write(eventfd_, &one, sizeof(one)); + } + } + + template + void file_desc_add(int fd, int flags, Callback callback) { + struct epoll_event ev; + ev.events = flags; + auto task = make_lambda_task(std::forward(callback)); + ev.data.ptr = new epoll_data_t{fd, std::move(task)}; + if (epoll_ctl(epollfd_, EPOLL_CTL_ADD, fd, &ev) < 0) { + oops("epoll_ctl() " + std::to_string(errno)); + } + } + + void file_desc_del(int fd) { + if (epoll_ctl(epollfd_, EPOLL_CTL_DEL, fd, nullptr) < 0) { + oops("epoll_ctl() " + std::to_string(errno)); + } + } + +private: + bool closed_ = true; + int epollfd_ = -1; + int eventfd_ = -1; + std::deque task_queue_; +}; + +} // cops diff --git a/future.h b/future.h new file mode 100644 index 0000000..63ea34b --- /dev/null +++ b/future.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include + +namespace cops { + +template +class future_t { +public: + enum class status_t { + kInit = 0, + kHasValue, + kHasCallback, + kDone, + }; + + using Fn = std::function; + +public: + future_t() = default; + ~future_t() = default; + + void set_value(T&& value) { + auto s = status_.load(std::memory_order_acquire); + if (s == status_t::kHasValue || s == status_t::kDone) { + return; + } + value_ = std::move(value); + if (s == status_t::kHasCallback) { + invoke(); + return; + } + s = status_.exchange(status_t::kHasValue, std::memory_order_acq_rel); + // double check, maybe exchange after set_callback + if (s == status_t::kHasCallback) { + invoke(); + } + } + + T&& value() { + return std::move(value_); + } + + bool has_value() { + return status_t::kHasValue == status_.load(std::memory_order_acquire); + } + + void set_callback(Fn&& fn) { + auto s = status_.load(std::memory_order_acquire); + if (s == status_t::kHasCallback || s == status_t::kDone) { + return; + } + callback_ = std::forward(fn); + if (s == status_t::kHasValue) { + invoke(); + return; + } + s = status_.exchange(status_t::kHasCallback, std::memory_order_acq_rel); + // double check, maybe exchange after set_value + if (s == status_t::kHasValue) { + invoke(); + } + } + +private: + void set_done() { + status_.store(status_t::kDone, std::memory_order_release); + } + void invoke() { + set_done(); + callback_(); + } + +private: + T value_; + std::atomic status_ = status_t::kInit; + Fn callback_; +}; + +} // cops diff --git a/future_test.cpp b/future_test.cpp new file mode 100644 index 0000000..bf48ea1 --- /dev/null +++ b/future_test.cpp @@ -0,0 +1,25 @@ +#include +#include "future.h" + +int main() { + auto future = std::make_shared>(); + std::cout << future->has_value() << std::endl; + future->set_value(100); + std::cout << future->has_value() << std::endl; + future->set_callback([]() { + std::cout << "callback" << std::endl; + }); + future->set_value(101); + future->set_callback([]() { + std::cout << "callback" << std::endl; + }); + std::cout << future->value() << std::endl; + + auto future2 = std::make_shared>(); + future2->set_callback([]() { + std::cout << "callback2" << std::endl; + }); + future2->set_value(10.1); + std::cout << future2->value() << std::endl; + return 0; +} diff --git a/loop_test.cpp b/loop_test.cpp new file mode 100644 index 0000000..44d2bc2 --- /dev/null +++ b/loop_test.cpp @@ -0,0 +1,47 @@ +#include +#include "coroutine.h" +#include "event_loop.h" + +int main() { + auto loop = std::make_unique(); + loop->call_soon([&loop]() { + std::cout << "task1" << std::endl; + loop->call_soon([&loop]() { + std::cout << "task2" << std::endl; + }); + }); + + loop->create_task([&loop]() { + int server = loop->create_server("127.0.0.1", 9999); + std::cout << "create server " << server << std::endl; + while (1) { + int conn = loop->sock_accept(server); + std::cout << "server " << server << " new conn " << conn << std::endl; + loop->create_task([&loop, conn]() { + char buf[128]; + ssize_t n = loop->sock_recv(conn, buf, 128, 0); + std::cout << buf << std::endl; + n = loop->sock_send(conn, buf, n, 0); + close(conn); + }); + } + }); + + loop->create_task([&loop]() { + int server = loop->create_server("127.0.0.1", 9998); + std::cout << "create server " << server << std::endl; + while (1) { + int conn = loop->sock_accept(server); + std::cout << "server " << server << " new conn " << conn << std::endl; + loop->create_task([&loop, conn]() { + char buf[128]; + ssize_t n = loop->sock_recv(conn, buf, 128, 0); + std::cout << buf << std::endl; + n = loop->sock_send(conn, buf, n, 0); + close(conn); + }); + } + }); + loop->run_forever(); + return 0; +} diff --git a/main.cpp b/switch_test.cpp similarity index 96% rename from main.cpp rename to switch_test.cpp index 4d4a10d..703a6a3 100644 --- a/main.cpp +++ b/switch_test.cpp @@ -1,5 +1,5 @@ #include -#include +#include "coroutine.h" int main() { std::unique_ptr coro; diff --git a/tcp_client.py b/tcp_client.py new file mode 100644 index 0000000..2a41c5b --- /dev/null +++ b/tcp_client.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 + +import sys +import socket + +with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect((sys.argv[1], int(sys.argv[2]))) + s.sendall(b'hello') + data = s.recv(1024) + print(data) +