Add TLS 1.2+ version for contexts
[folly.git] / folly / io / async / test / TestSSLServer.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/SocketAddress.h>
19 #include <folly/experimental/TestUtil.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSocket.h>
23 #include <folly/io/async/AsyncTimeout.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/ssl/SSLErrors.h>
27 #include <folly/portability/GTest.h>
28 #include <folly/portability/Sockets.h>
29 #include <folly/portability/Unistd.h>
30
31 #include <fcntl.h>
32 #include <sys/types.h>
33 #include <list>
34
35 namespace folly {
36
37 extern const char* kTestCert;
38 extern const char* kTestKey;
39 extern const char* kTestCA;
40
41 enum StateEnum { STATE_WAITING, STATE_SUCCEEDED, STATE_FAILED };
42
43 class HandshakeCallback;
44
45 class SSLServerAcceptCallbackBase : public AsyncServerSocket::AcceptCallback {
46  public:
47   explicit SSLServerAcceptCallbackBase(HandshakeCallback* hcb)
48       : state(STATE_WAITING), hcb_(hcb) {}
49
50   ~SSLServerAcceptCallbackBase() override {
51     EXPECT_EQ(STATE_SUCCEEDED, state);
52   }
53
54   void acceptError(const std::exception& ex) noexcept override {
55     LOG(WARNING) << "SSLServerAcceptCallbackBase::acceptError " << ex.what();
56     state = STATE_FAILED;
57   }
58
59   void connectionAccepted(
60       int fd,
61       const SocketAddress& /* clientAddr */) noexcept override {
62     if (socket_) {
63       socket_->detachEventBase();
64     }
65     LOG(INFO) << "Connection accepted";
66     try {
67       // Create a AsyncSSLSocket object with the fd. The socket should be
68       // added to the event base and in the state of accepting SSL connection.
69       socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
70     } catch (const std::exception& e) {
71       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
72                     "object with socket "
73                  << e.what() << fd;
74       ::close(fd);
75       acceptError(e);
76       return;
77     }
78
79     connAccepted(socket_);
80   }
81
82   virtual void connAccepted(const std::shared_ptr<AsyncSSLSocket>& s) = 0;
83
84   void detach() {
85     socket_->detachEventBase();
86   }
87
88   StateEnum state;
89   HandshakeCallback* hcb_;
90   std::shared_ptr<SSLContext> ctx_;
91   std::shared_ptr<AsyncSSLSocket> socket_;
92   EventBase* base_;
93 };
94
95 class TestSSLServer {
96  public:
97   // Create a TestSSLServer.
98   // This immediately starts listening on the given port.
99   explicit TestSSLServer(
100       SSLServerAcceptCallbackBase* acb,
101       bool enableTFO = false);
102   explicit TestSSLServer(
103       SSLServerAcceptCallbackBase* acb,
104       std::shared_ptr<SSLContext> ctx,
105       bool enableTFO = false);
106
107   // Kills the thread.
108   virtual ~TestSSLServer();
109
110   EventBase& getEventBase() {
111     return evb_;
112   }
113
114   void loadTestCerts();
115
116   const SocketAddress& getAddress() const {
117     return address_;
118   }
119
120  protected:
121   EventBase evb_;
122   std::shared_ptr<SSLContext> ctx_;
123   SSLServerAcceptCallbackBase* acb_;
124   std::shared_ptr<AsyncServerSocket> socket_;
125   SocketAddress address_;
126   std::thread thread_;
127
128  private:
129   void init(bool);
130 };
131 }