Switch uses of networking headers to <folly/portability/Sockets.h>
[folly.git] / folly / io / test / ShutdownSocketSetTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/ShutdownSocketSet.h>
17
18 #include <atomic>
19 #include <chrono>
20 #include <thread>
21
22 #include <glog/logging.h>
23 #include <gtest/gtest.h>
24
25 #include <folly/portability/Sockets.h>
26
27 using folly::ShutdownSocketSet;
28
29 namespace folly { namespace test {
30
31 ShutdownSocketSet shutdownSocketSet;
32
33 class Server {
34  public:
35   Server();
36
37   void stop(bool abortive);
38   void join();
39   int port() const { return port_; }
40   int closeClients(bool abortive);
41
42  private:
43   int acceptSocket_;
44   int port_;
45   enum StopMode {
46     NO_STOP,
47     ORDERLY,
48     ABORTIVE
49   };
50   std::atomic<StopMode> stop_;
51   std::thread serverThread_;
52   std::vector<int> fds_;
53 };
54
55 Server::Server()
56   : acceptSocket_(-1),
57     port_(0),
58     stop_(NO_STOP) {
59   acceptSocket_ = socket(PF_INET, SOCK_STREAM, 0);
60   CHECK_ERR(acceptSocket_);
61   shutdownSocketSet.add(acceptSocket_);
62
63   sockaddr_in addr;
64   addr.sin_family = AF_INET;
65   addr.sin_port = 0;
66   addr.sin_addr.s_addr = INADDR_ANY;
67   CHECK_ERR(bind(acceptSocket_,
68                  reinterpret_cast<const sockaddr*>(&addr),
69                  sizeof(addr)));
70
71   CHECK_ERR(listen(acceptSocket_, 10));
72
73   socklen_t addrLen = sizeof(addr);
74   CHECK_ERR(getsockname(acceptSocket_,
75                         reinterpret_cast<sockaddr*>(&addr),
76                         &addrLen));
77
78   port_ = ntohs(addr.sin_port);
79
80   serverThread_ = std::thread([this] {
81     while (stop_ == NO_STOP) {
82       sockaddr_in peer;
83       socklen_t peerLen = sizeof(peer);
84       int fd = accept(acceptSocket_,
85                       reinterpret_cast<sockaddr*>(&peer),
86                       &peerLen);
87       if (fd == -1) {
88         if (errno == EINTR) {
89           continue;
90         }
91         if (errno == EINVAL || errno == ENOTSOCK) {  // socket broken
92           break;
93         }
94       }
95       CHECK_ERR(fd);
96       shutdownSocketSet.add(fd);
97       fds_.push_back(fd);
98     }
99
100     if (stop_ != NO_STOP) {
101       closeClients(stop_ == ABORTIVE);
102     }
103
104     shutdownSocketSet.close(acceptSocket_);
105     acceptSocket_ = -1;
106     port_ = 0;
107   });
108 }
109
110 int Server::closeClients(bool abortive) {
111   for (int fd : fds_) {
112     if (abortive) {
113       struct linger l = {1, 0};
114       CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
115     }
116     shutdownSocketSet.close(fd);
117   }
118   int n = fds_.size();
119   fds_.clear();
120   return n;
121 }
122
123 void Server::stop(bool abortive) {
124   stop_ = abortive ? ABORTIVE : ORDERLY;
125   shutdown(acceptSocket_, SHUT_RDWR);
126 }
127
128 void Server::join() {
129   serverThread_.join();
130 }
131
132 int createConnectedSocket(int port) {
133   int sock = socket(PF_INET, SOCK_STREAM, 0);
134   CHECK_ERR(sock);
135   sockaddr_in addr;
136   addr.sin_family = AF_INET;
137   addr.sin_port = htons(port);
138   addr.sin_addr.s_addr = htonl((127 << 24) | 1);  // XXX
139   CHECK_ERR(connect(sock,
140                     reinterpret_cast<const sockaddr*>(&addr),
141                     sizeof(addr)));
142   return sock;
143 }
144
145 void runCloseTest(bool abortive) {
146   Server server;
147
148   int sock = createConnectedSocket(server.port());
149
150   std::thread stopper([&server, abortive] {
151     std::this_thread::sleep_for(std::chrono::milliseconds(200));
152     server.stop(abortive);
153     server.join();
154   });
155
156   char c;
157   int r = read(sock, &c, 1);
158   if (abortive) {
159     int e = errno;
160     EXPECT_EQ(-1, r);
161     EXPECT_EQ(ECONNRESET, e);
162   } else {
163     EXPECT_EQ(0, r);
164   }
165
166   close(sock);
167
168   stopper.join();
169
170   EXPECT_EQ(0, server.closeClients(false));  // closed by server when it exited
171 }
172
173 TEST(ShutdownSocketSetTest, OrderlyClose) {
174   runCloseTest(false);
175 }
176
177 TEST(ShutdownSocketSetTest, AbortiveClose) {
178   runCloseTest(true);
179 }
180
181 void runKillTest(bool abortive) {
182   Server server;
183
184   int sock = createConnectedSocket(server.port());
185
186   std::thread killer([&server, abortive] {
187     std::this_thread::sleep_for(std::chrono::milliseconds(200));
188     shutdownSocketSet.shutdownAll(abortive);
189     server.join();
190   });
191
192   char c;
193   int r = read(sock, &c, 1);
194
195   // "abortive" is just a hint for ShutdownSocketSet, so accept both
196   // behaviors
197   if (abortive) {
198     if (r == -1) {
199       EXPECT_EQ(ECONNRESET, errno);
200     } else {
201       EXPECT_EQ(r, 0);
202     }
203   } else {
204     EXPECT_EQ(0, r);
205   }
206
207   close(sock);
208
209   killer.join();
210
211   // NOT closed by server when it exited
212   EXPECT_EQ(1, server.closeClients(false));
213 }
214
215 TEST(ShutdownSocketSetTest, OrderlyKill) {
216   runKillTest(false);
217 }
218
219 TEST(ShutdownSocketSetTest, AbortiveKill) {
220   runKillTest(true);
221 }
222
223 }}  // namespaces