Adds writer test case for RCU
[folly.git] / folly / io / test / ShutdownSocketSetTest.cpp
1 /*
2  * Copyright 2013-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/ShutdownSocketSet.h>
17
18 #include <atomic>
19 #include <chrono>
20 #include <thread>
21
22 #include <glog/logging.h>
23
24 #include <folly/portability/GTest.h>
25 #include <folly/portability/Sockets.h>
26
27 using folly::ShutdownSocketSet;
28
29 namespace fsp = folly::portability::sockets;
30
31 namespace folly {
32 namespace test {
33
34 ShutdownSocketSet shutdownSocketSet;
35
36 class Server {
37  public:
38   Server();
39
40   void stop(bool abortive);
41   void join();
42   int port() const { return port_; }
43   int closeClients(bool abortive);
44
45  private:
46   int acceptSocket_;
47   int port_;
48   enum StopMode {
49     NO_STOP,
50     ORDERLY,
51     ABORTIVE
52   };
53   std::atomic<StopMode> stop_;
54   std::thread serverThread_;
55   std::vector<int> fds_;
56 };
57
58 Server::Server()
59   : acceptSocket_(-1),
60     port_(0),
61     stop_(NO_STOP) {
62   acceptSocket_ = fsp::socket(PF_INET, SOCK_STREAM, 0);
63   CHECK_ERR(acceptSocket_);
64   shutdownSocketSet.add(acceptSocket_);
65
66   sockaddr_in addr;
67   addr.sin_family = AF_INET;
68   addr.sin_port = 0;
69   addr.sin_addr.s_addr = INADDR_ANY;
70   CHECK_ERR(bind(acceptSocket_,
71                  reinterpret_cast<const sockaddr*>(&addr),
72                  sizeof(addr)));
73
74   CHECK_ERR(listen(acceptSocket_, 10));
75
76   socklen_t addrLen = sizeof(addr);
77   CHECK_ERR(getsockname(acceptSocket_,
78                         reinterpret_cast<sockaddr*>(&addr),
79                         &addrLen));
80
81   port_ = ntohs(addr.sin_port);
82
83   serverThread_ = std::thread([this] {
84     while (stop_ == NO_STOP) {
85       sockaddr_in peer;
86       socklen_t peerLen = sizeof(peer);
87       int fd = accept(acceptSocket_,
88                       reinterpret_cast<sockaddr*>(&peer),
89                       &peerLen);
90       if (fd == -1) {
91         if (errno == EINTR) {
92           continue;
93         }
94         if (errno == EINVAL || errno == ENOTSOCK) {  // socket broken
95           break;
96         }
97       }
98       CHECK_ERR(fd);
99       shutdownSocketSet.add(fd);
100       fds_.push_back(fd);
101     }
102
103     if (stop_ != NO_STOP) {
104       closeClients(stop_ == ABORTIVE);
105     }
106
107     shutdownSocketSet.close(acceptSocket_);
108     acceptSocket_ = -1;
109     port_ = 0;
110   });
111 }
112
113 int Server::closeClients(bool abortive) {
114   for (int fd : fds_) {
115     if (abortive) {
116       struct linger l = {1, 0};
117       CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
118     }
119     shutdownSocketSet.close(fd);
120   }
121   int n = fds_.size();
122   fds_.clear();
123   return n;
124 }
125
126 void Server::stop(bool abortive) {
127   stop_ = abortive ? ABORTIVE : ORDERLY;
128   shutdown(acceptSocket_, SHUT_RDWR);
129 }
130
131 void Server::join() {
132   serverThread_.join();
133 }
134
135 int createConnectedSocket(int port) {
136   int sock = fsp::socket(PF_INET, SOCK_STREAM, 0);
137   CHECK_ERR(sock);
138   sockaddr_in addr;
139   addr.sin_family = AF_INET;
140   addr.sin_port = htons(port);
141   addr.sin_addr.s_addr = htonl((127 << 24) | 1);  // XXX
142   CHECK_ERR(connect(sock,
143                     reinterpret_cast<const sockaddr*>(&addr),
144                     sizeof(addr)));
145   return sock;
146 }
147
148 void runCloseTest(bool abortive) {
149   Server server;
150
151   int sock = createConnectedSocket(server.port());
152
153   std::thread stopper([&server, abortive] {
154     std::this_thread::sleep_for(std::chrono::milliseconds(200));
155     server.stop(abortive);
156     server.join();
157   });
158
159   char c;
160   int r = read(sock, &c, 1);
161   if (abortive) {
162     int e = errno;
163     EXPECT_EQ(-1, r);
164     EXPECT_EQ(ECONNRESET, e);
165   } else {
166     EXPECT_EQ(0, r);
167   }
168
169   close(sock);
170
171   stopper.join();
172
173   EXPECT_EQ(0, server.closeClients(false));  // closed by server when it exited
174 }
175
176 TEST(ShutdownSocketSetTest, OrderlyClose) {
177   runCloseTest(false);
178 }
179
180 TEST(ShutdownSocketSetTest, AbortiveClose) {
181   runCloseTest(true);
182 }
183
184 void runKillTest(bool abortive) {
185   Server server;
186
187   int sock = createConnectedSocket(server.port());
188
189   std::thread killer([&server, abortive] {
190     std::this_thread::sleep_for(std::chrono::milliseconds(200));
191     shutdownSocketSet.shutdownAll(abortive);
192     server.join();
193   });
194
195   char c;
196   int r = read(sock, &c, 1);
197
198   // "abortive" is just a hint for ShutdownSocketSet, so accept both
199   // behaviors
200   if (abortive) {
201     if (r == -1) {
202       EXPECT_EQ(ECONNRESET, errno);
203     } else {
204       EXPECT_EQ(r, 0);
205     }
206   } else {
207     EXPECT_EQ(0, r);
208   }
209
210   close(sock);
211
212   killer.join();
213
214   // NOT closed by server when it exited
215   EXPECT_EQ(1, server.closeClients(false));
216 }
217
218 TEST(ShutdownSocketSetTest, OrderlyKill) {
219   runKillTest(false);
220 }
221
222 TEST(ShutdownSocketSetTest, AbortiveKill) {
223   runKillTest(true);
224 }
225 } // namespace test
226 } // namespace folly