Make RequestContext provider overridable in order to save cost of setContext() on...
[folly.git] / folly / io / async / test / RequestContextTest.cpp
1 /*
2  * Copyright 2004-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <thread>
18
19 #include <folly/Memory.h>
20 #include <folly/io/async/EventBase.h>
21 #include <folly/io/async/Request.h>
22 #include <folly/portability/GTest.h>
23
24 using namespace folly;
25
26 class TestData : public RequestData {
27  public:
28   explicit TestData(int data) : data_(data) {}
29   ~TestData() override {}
30   void onSet() override {
31     set_++;
32   }
33   void onUnset() override {
34     unset_++;
35   }
36   int set_ = 0, unset_ = 0;
37   int data_;
38 };
39
40 TEST(RequestContext, SimpleTest) {
41   EventBase base;
42
43
44   // There should always be a default context with get()
45   EXPECT_TRUE(RequestContext::get() != nullptr);
46
47
48   // but not with saveContext()
49   EXPECT_EQ(RequestContext::saveContext(), nullptr);
50   RequestContext::create();
51   EXPECT_NE(RequestContext::saveContext(), nullptr);
52   RequestContext::create();
53   EXPECT_NE(RequestContext::saveContext(), nullptr);
54
55   EXPECT_EQ(nullptr, RequestContext::get()->getContextData("test"));
56
57   RequestContext::get()->setContextData("test", std::make_unique<TestData>(10));
58   base.runInEventBaseThread([&](){
59       EXPECT_TRUE(RequestContext::get() != nullptr);
60       auto data = dynamic_cast<TestData*>(
61         RequestContext::get()->getContextData("test"))->data_;
62       EXPECT_EQ(10, data);
63       base.terminateLoopSoon();
64     });
65   auto th = std::thread([&](){
66       base.loopForever();
67   });
68   th.join();
69   EXPECT_TRUE(RequestContext::get() != nullptr);
70   auto a = dynamic_cast<TestData*>(
71     RequestContext::get()->getContextData("test"));
72   auto data = a->data_;
73   EXPECT_EQ(10, data);
74
75   RequestContext::setContext(std::shared_ptr<RequestContext>());
76   // There should always be a default context
77   EXPECT_TRUE(nullptr != RequestContext::get());
78 }
79
80 TEST(RequestContext, nonDefaultContextsAreThreadLocal) {
81   RequestContext* ctx1 = nullptr;
82   RequestContext* ctx2 = nullptr;
83
84   std::vector<std::thread> ts;
85   for (size_t i = 0; i < 2; ++i) {
86     auto*& ctx = (i == 0 ? ctx1 : ctx2);
87     ts.emplace_back([&ctx]() {
88       RequestContext::create();
89       ctx = RequestContext::get();
90     });
91   }
92   for (auto& t : ts) {
93     t.join();
94   }
95
96   EXPECT_NE(nullptr, ctx1);
97   EXPECT_NE(nullptr, ctx2);
98   EXPECT_NE(ctx1, ctx2);
99 }
100
101 TEST(RequestContext, customRequestContextProvider) {
102   auto customContext = std::make_shared<RequestContext>();
103   auto customProvider = [&customContext]() -> std::shared_ptr<RequestContext>& {
104     return customContext;
105   };
106
107   auto* const originalContext = RequestContext::get();
108   EXPECT_NE(nullptr, originalContext);
109
110   // Install new RequestContext provider
111   auto originalProvider =
112       RequestContext::setRequestContextProvider(std::move(customProvider));
113
114   auto* const newContext = RequestContext::get();
115   EXPECT_EQ(customContext.get(), newContext);
116   EXPECT_NE(originalContext, newContext);
117
118   // Restore original RequestContext provider
119   RequestContext::setRequestContextProvider(std::move(originalProvider));
120   EXPECT_EQ(originalContext, RequestContext::get());
121 }
122
123 TEST(RequestContext, setIfAbsentTest) {
124   EXPECT_TRUE(RequestContext::get() != nullptr);
125
126   RequestContext::get()->setContextData("test", std::make_unique<TestData>(10));
127   EXPECT_FALSE(RequestContext::get()->setContextDataIfAbsent(
128       "test", std::make_unique<TestData>(20)));
129   EXPECT_EQ(10,
130             dynamic_cast<TestData*>(
131                 RequestContext::get()->getContextData("test"))->data_);
132
133   EXPECT_TRUE(RequestContext::get()->setContextDataIfAbsent(
134       "test2", std::make_unique<TestData>(20)));
135   EXPECT_EQ(20,
136             dynamic_cast<TestData*>(
137                 RequestContext::get()->getContextData("test2"))->data_);
138
139   RequestContext::setContext(std::shared_ptr<RequestContext>());
140   EXPECT_TRUE(nullptr != RequestContext::get());
141 }
142
143 TEST(RequestContext, testSetUnset) {
144   RequestContext::create();
145   auto ctx1 = RequestContext::saveContext();
146   ctx1->setContextData("test", std::make_unique<TestData>(10));
147   auto testData1 = dynamic_cast<TestData*>(ctx1->getContextData("test"));
148
149   // Override RequestContext
150   RequestContext::create();
151   auto ctx2 = RequestContext::saveContext();
152   ctx2->setContextData("test", std::make_unique<TestData>(20));
153   auto testData2 = dynamic_cast<TestData*>(ctx2->getContextData("test"));
154
155   // Check ctx1->onUnset was called
156   EXPECT_EQ(0, testData1->set_);
157   EXPECT_EQ(1, testData1->unset_);
158
159   RequestContext::setContext(ctx1);
160   EXPECT_EQ(1, testData1->set_);
161   EXPECT_EQ(1, testData1->unset_);
162   EXPECT_EQ(0, testData2->set_);
163   EXPECT_EQ(1, testData2->unset_);
164
165   RequestContext::setContext(ctx2);
166   EXPECT_EQ(1, testData1->set_);
167   EXPECT_EQ(2, testData1->unset_);
168   EXPECT_EQ(1, testData2->set_);
169   EXPECT_EQ(1, testData2->unset_);
170 }
171
172 TEST(RequestContext, deadlockTest) {
173   class DeadlockTestData : public RequestData {
174    public:
175     explicit DeadlockTestData(const std::string& val) : val_(val) {}
176
177     ~DeadlockTestData() override {
178       RequestContext::get()->setContextData(
179           val_, std::make_unique<TestData>(1));
180     }
181
182     void onSet() override {}
183
184     void onUnset() override {}
185
186     std::string val_;
187   };
188
189   RequestContext::get()->setContextData(
190       "test", std::make_unique<DeadlockTestData>("test2"));
191   RequestContext::get()->clearContextData("test");
192 }