Move AsyncSocket tests from thrift to folly
[folly.git] / folly / io / test / ShutdownSocketSetTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/ShutdownSocketSet.h>
17
18 #include <atomic>
19 #include <chrono>
20 #include <thread>
21
22 #include <netinet/in.h>
23 #include <netinet/tcp.h>
24 #include <sys/socket.h>
25
26 #include <glog/logging.h>
27 #include <gtest/gtest.h>
28
29 using folly::ShutdownSocketSet;
30
31 namespace folly { namespace test {
32
33 ShutdownSocketSet shutdownSocketSet;
34
35 class Server {
36  public:
37   Server();
38
39   void stop(bool abortive);
40   void join();
41   int port() const { return port_; }
42   int closeClients(bool abortive);
43
44  private:
45   int acceptSocket_;
46   int port_;
47   enum StopMode {
48     NO_STOP,
49     ORDERLY,
50     ABORTIVE
51   };
52   std::atomic<StopMode> stop_;
53   std::thread serverThread_;
54   std::vector<int> fds_;
55 };
56
57 Server::Server()
58   : acceptSocket_(-1),
59     port_(0),
60     stop_(NO_STOP) {
61   acceptSocket_ = socket(PF_INET, SOCK_STREAM, 0);
62   CHECK_ERR(acceptSocket_);
63   shutdownSocketSet.add(acceptSocket_);
64
65   sockaddr_in addr;
66   addr.sin_family = AF_INET;
67   addr.sin_port = 0;
68   addr.sin_addr.s_addr = INADDR_ANY;
69   CHECK_ERR(bind(acceptSocket_,
70                  reinterpret_cast<const sockaddr*>(&addr),
71                  sizeof(addr)));
72
73   CHECK_ERR(listen(acceptSocket_, 10));
74
75   socklen_t addrLen = sizeof(addr);
76   CHECK_ERR(getsockname(acceptSocket_,
77                         reinterpret_cast<sockaddr*>(&addr),
78                         &addrLen));
79
80   port_ = ntohs(addr.sin_port);
81
82   serverThread_ = std::thread([this] {
83     while (stop_ == NO_STOP) {
84       sockaddr_in peer;
85       socklen_t peerLen = sizeof(peer);
86       int fd = accept(acceptSocket_,
87                       reinterpret_cast<sockaddr*>(&peer),
88                       &peerLen);
89       if (fd == -1) {
90         if (errno == EINTR) {
91           continue;
92         }
93         if (errno == EINVAL || errno == ENOTSOCK) {  // socket broken
94           break;
95         }
96       }
97       CHECK_ERR(fd);
98       shutdownSocketSet.add(fd);
99       fds_.push_back(fd);
100     }
101
102     if (stop_ != NO_STOP) {
103       closeClients(stop_ == ABORTIVE);
104     }
105
106     shutdownSocketSet.close(acceptSocket_);
107     acceptSocket_ = -1;
108     port_ = 0;
109   });
110 }
111
112 int Server::closeClients(bool abortive) {
113   for (int fd : fds_) {
114     if (abortive) {
115       struct linger l = {1, 0};
116       CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
117     }
118     shutdownSocketSet.close(fd);
119   }
120   int n = fds_.size();
121   fds_.clear();
122   return n;
123 }
124
125 void Server::stop(bool abortive) {
126   stop_ = abortive ? ABORTIVE : ORDERLY;
127   shutdown(acceptSocket_, SHUT_RDWR);
128 }
129
130 void Server::join() {
131   serverThread_.join();
132 }
133
134 int createConnectedSocket(int port) {
135   int sock = socket(PF_INET, SOCK_STREAM, 0);
136   CHECK_ERR(sock);
137   sockaddr_in addr;
138   addr.sin_family = AF_INET;
139   addr.sin_port = htons(port);
140   addr.sin_addr.s_addr = htonl((127 << 24) | 1);  // XXX
141   CHECK_ERR(connect(sock,
142                     reinterpret_cast<const sockaddr*>(&addr),
143                     sizeof(addr)));
144   return sock;
145 }
146
147 void runCloseTest(bool abortive) {
148   Server server;
149
150   int sock = createConnectedSocket(server.port());
151
152   std::thread stopper([&server, abortive] {
153     std::this_thread::sleep_for(std::chrono::milliseconds(200));
154     server.stop(abortive);
155     server.join();
156   });
157
158   char c;
159   int r = read(sock, &c, 1);
160   if (abortive) {
161     int e = errno;
162     EXPECT_EQ(-1, r);
163     EXPECT_EQ(ECONNRESET, e);
164   } else {
165     EXPECT_EQ(0, r);
166   }
167
168   close(sock);
169
170   stopper.join();
171
172   EXPECT_EQ(0, server.closeClients(false));  // closed by server when it exited
173 }
174
175 TEST(ShutdownSocketSetTest, OrderlyClose) {
176   runCloseTest(false);
177 }
178
179 TEST(ShutdownSocketSetTest, AbortiveClose) {
180   runCloseTest(true);
181 }
182
183 void runKillTest(bool abortive) {
184   Server server;
185
186   int sock = createConnectedSocket(server.port());
187
188   std::thread killer([&server, abortive] {
189     std::this_thread::sleep_for(std::chrono::milliseconds(200));
190     shutdownSocketSet.shutdownAll(abortive);
191     server.join();
192   });
193
194   char c;
195   int r = read(sock, &c, 1);
196
197   // "abortive" is just a hint for ShutdownSocketSet, so accept both
198   // behaviors
199   if (abortive) {
200     if (r == -1) {
201       EXPECT_EQ(ECONNRESET, errno);
202     } else {
203       EXPECT_EQ(r, 0);
204     }
205   } else {
206     EXPECT_EQ(0, r);
207   }
208
209   close(sock);
210
211   killer.join();
212
213   // NOT closed by server when it exited
214   EXPECT_EQ(1, server.closeClients(false));
215 }
216
217 TEST(ShutdownSocketSetTest, OrderlyKill) {
218   runKillTest(false);
219 }
220
221 TEST(ShutdownSocketSetTest, AbortiveKill) {
222   runKillTest(true);
223 }
224
225 }}  // namespaces
226
227 int main(int argc, char *argv[]) {
228   testing::InitGoogleTest(&argc, argv);
229   google::ParseCommandLineFlags(&argc, &argv, true);
230   return RUN_ALL_TESTS();
231 }