Adds writer test case for RCU
[folly.git] / folly / io / async / test / TestSSLServer.h
1 /*
2  * Copyright 2017-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/SocketAddress.h>
19 #include <folly/experimental/TestUtil.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSocket.h>
23 #include <folly/io/async/AsyncTimeout.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/ssl/SSLErrors.h>
27 #include <folly/portability/GTest.h>
28 #include <folly/portability/Sockets.h>
29 #include <folly/portability/Unistd.h>
30
31 #include <fcntl.h>
32 #include <sys/types.h>
33 #include <list>
34
35 namespace folly {
36
37 extern const char* kTestCert;
38 extern const char* kTestKey;
39 extern const char* kTestCA;
40
41 extern const char* kClientTestCert;
42 extern const char* kClientTestKey;
43 extern const char* kClientTestCA;
44
45 enum StateEnum { STATE_WAITING, STATE_SUCCEEDED, STATE_FAILED };
46
47 class HandshakeCallback;
48
49 class SSLServerAcceptCallbackBase : public AsyncServerSocket::AcceptCallback {
50  public:
51   explicit SSLServerAcceptCallbackBase(HandshakeCallback* hcb)
52       : state(STATE_WAITING), hcb_(hcb) {}
53
54   ~SSLServerAcceptCallbackBase() override {
55     EXPECT_EQ(STATE_SUCCEEDED, state);
56   }
57
58   void acceptError(const std::exception& ex) noexcept override {
59     LOG(WARNING) << "SSLServerAcceptCallbackBase::acceptError " << ex.what();
60     state = STATE_FAILED;
61   }
62
63   void connectionAccepted(
64       int fd,
65       const SocketAddress& /* clientAddr */) noexcept override {
66     if (socket_) {
67       socket_->detachEventBase();
68     }
69     LOG(INFO) << "Connection accepted";
70     try {
71       // Create a AsyncSSLSocket object with the fd. The socket should be
72       // added to the event base and in the state of accepting SSL connection.
73       socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
74     } catch (const std::exception& e) {
75       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
76                     "object with socket "
77                  << e.what() << fd;
78       ::close(fd);
79       acceptError(e);
80       return;
81     }
82
83     connAccepted(socket_);
84   }
85
86   virtual void connAccepted(const std::shared_ptr<AsyncSSLSocket>& s) = 0;
87
88   void detach() {
89     socket_->detachEventBase();
90   }
91
92   StateEnum state;
93   HandshakeCallback* hcb_;
94   std::shared_ptr<SSLContext> ctx_;
95   std::shared_ptr<AsyncSSLSocket> socket_;
96   EventBase* base_;
97 };
98
99 class TestSSLServer {
100  public:
101   // Create a TestSSLServer.
102   // This immediately starts listening on the given port.
103   explicit TestSSLServer(
104       SSLServerAcceptCallbackBase* acb,
105       bool enableTFO = false);
106   explicit TestSSLServer(
107       SSLServerAcceptCallbackBase* acb,
108       std::shared_ptr<SSLContext> ctx,
109       bool enableTFO = false);
110
111   // Kills the thread.
112   virtual ~TestSSLServer();
113
114   EventBase& getEventBase() {
115     return evb_;
116   }
117
118   void loadTestCerts();
119
120   const SocketAddress& getAddress() const {
121     return address_;
122   }
123
124  protected:
125   EventBase evb_;
126   std::shared_ptr<SSLContext> ctx_;
127   SSLServerAcceptCallbackBase* acb_;
128   std::shared_ptr<AsyncServerSocket> socket_;
129   SocketAddress address_;
130   std::thread thread_;
131
132  private:
133   void init(bool);
134 };
135 } // namespace folly