ab5c3c2c89383d5b26ab916c2508134bd34629f7
[firefly-linux-kernel-4.4.55.git] / tools / gator / daemon / OlySocket.cpp
1 /**
2  * Copyright (C) ARM Limited 2010-2013. All rights reserved.
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License version 2 as
6  * published by the Free Software Foundation.
7  */
8
9 #include "OlySocket.h"
10
11 #include <stdio.h>
12 #ifdef WIN32
13 #include <Winsock2.h>
14 #include <ws2tcpip.h>
15 #else
16 #include <netinet/in.h>
17 #include <sys/socket.h>
18 #include <unistd.h>
19 #include <netdb.h>
20 #endif
21
22 #include "Logging.h"
23
24 #ifdef WIN32
25 #define CLOSE_SOCKET(x) closesocket(x)
26 #define SHUTDOWN_RX_TX SD_BOTH
27 #define snprintf       _snprintf
28 #else
29 #define CLOSE_SOCKET(x) close(x)
30 #define SHUTDOWN_RX_TX SHUT_RDWR
31 #endif
32
33 OlySocket::OlySocket(int port, bool multiple) {
34 #ifdef WIN32
35   WSADATA wsaData;
36   if (WSAStartup(0x0202, &wsaData) != 0) {
37     logg->logError(__FILE__, __LINE__, "Windows socket initialization failed");
38     handleException();
39   }
40 #endif
41
42   if (multiple) {
43     createServerSocket(port);
44   } else {
45     createSingleServerConnection(port);
46   }
47 }
48
49 OlySocket::OlySocket(int port, char* host) {
50   mFDServer = 0;
51   createClientSocket(host, port);
52 }
53
54 OlySocket::~OlySocket() {
55   if (mSocketID > 0) {
56     CLOSE_SOCKET(mSocketID);
57   }
58 }
59
60 void OlySocket::shutdownConnection() {
61   // Shutdown is primarily used to unblock other threads that are blocking on send/receive functions
62   shutdown(mSocketID, SHUTDOWN_RX_TX);
63 }
64
65 void OlySocket::closeSocket() {
66   // Used for closing an accepted socket but keeping the server socket active
67   if (mSocketID > 0) {
68     CLOSE_SOCKET(mSocketID);
69     mSocketID = -1;
70   }
71 }
72
73 void OlySocket::closeServerSocket() {
74   if (CLOSE_SOCKET(mFDServer) != 0) {
75     logg->logError(__FILE__, __LINE__, "Failed to close server socket.");
76     handleException();
77   }
78   mFDServer = 0;
79 }
80
81 void OlySocket::createClientSocket(char* hostname, int portno) {
82 #ifdef WIN32
83   // TODO: Implement for Windows
84 #else
85   char buf[32];
86   struct addrinfo hints, *res, *res0;
87
88   snprintf(buf, sizeof(buf), "%d", portno);
89   mSocketID = -1;
90   memset((void*)&hints, 0, sizeof(hints));
91   hints.ai_family = PF_UNSPEC;
92   hints.ai_socktype = SOCK_STREAM;
93
94   if (getaddrinfo(hostname, buf, &hints, &res0)) {
95     logg->logError(__FILE__, __LINE__, "Client socket failed to get address info for %s", hostname);
96     handleException();
97   }
98   for (res=res0; res!=NULL; res = res->ai_next) {
99     if ( res->ai_family != PF_INET || res->ai_socktype != SOCK_STREAM ) {
100       continue;
101     }
102     mSocketID = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
103     if (mSocketID < 0) {
104       continue;
105     }
106     if (connect(mSocketID, res->ai_addr, res->ai_addrlen) < 0) {
107       close(mSocketID);
108       mSocketID = -1;
109     }
110     if (mSocketID > 0) {
111       break;
112     }
113   }
114   freeaddrinfo(res0);
115   if (mSocketID <= 0) {
116     logg->logError(__FILE__, __LINE__, "Could not connect to client socket. Ensure ARM Streamline is running.");
117     handleException();
118   }
119 #endif
120 }
121
122 void OlySocket::createSingleServerConnection(int port) {
123   createServerSocket(port);
124
125   mSocketID = acceptConnection();
126   closeServerSocket();
127 }
128
129 void OlySocket::createServerSocket(int port) {
130   int family = AF_INET6;
131
132   // Create socket
133   mFDServer = socket(PF_INET6, SOCK_STREAM, IPPROTO_TCP);
134   if (mFDServer < 0) {
135     family = AF_INET;
136     mFDServer = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
137     if (mFDServer < 0) {
138       logg->logError(__FILE__, __LINE__, "Error creating server socket");
139       handleException();
140     }
141   }
142
143   // Enable address reuse, another solution would be to create the server socket once and only close it when the object exits
144   int on = 1;
145   if (setsockopt(mFDServer, SOL_SOCKET, SO_REUSEADDR, (const char*)&on, sizeof(on)) != 0) {
146     logg->logError(__FILE__, __LINE__, "Setting server socket options failed");
147     handleException();
148   }
149
150   // Create sockaddr_in structure, ensuring non-populated fields are zero
151   struct sockaddr_in6 sockaddr;
152   memset((void*)&sockaddr, 0, sizeof(sockaddr));
153   sockaddr.sin6_family = family;
154   sockaddr.sin6_port = htons(port);
155   sockaddr.sin6_addr = in6addr_any;
156
157   // Bind the socket to an address
158   if (bind(mFDServer, (const struct sockaddr*)&sockaddr, sizeof(sockaddr)) < 0) {
159     logg->logError(__FILE__, __LINE__, "Binding of server socket failed.\nIs an instance already running?");
160     handleException();
161   }
162
163   // Listen for connections on this socket
164   if (listen(mFDServer, 1) < 0) {
165     logg->logError(__FILE__, __LINE__, "Listening of server socket failed");
166     handleException();
167   }
168 }
169
170 // mSocketID is always set to the most recently accepted connection
171 // The user of this class should maintain the different socket connections, e.g. by forking the process
172 int OlySocket::acceptConnection() {
173   if (mFDServer <= 0) {
174     logg->logError(__FILE__, __LINE__, "Attempting multiple connections on a single connection server socket or attempting to accept on a client socket");
175     handleException();
176   }
177
178   // Accept a connection, note that this call blocks until a client connects
179   mSocketID = accept(mFDServer, NULL, NULL);
180   if (mSocketID < 0) {
181     logg->logError(__FILE__, __LINE__, "Socket acceptance failed");
182     handleException();
183   }
184   return mSocketID;
185 }
186
187 void OlySocket::send(char* buffer, int size) {
188   if (size <= 0 || buffer == NULL) {
189     return;
190   }
191
192   while (size > 0) {
193     int n = ::send(mSocketID, buffer, size, 0);
194     if (n < 0) {
195       logg->logError(__FILE__, __LINE__, "Socket send error");
196       handleException();
197     }
198     size -= n;
199     buffer += n;
200   }
201 }
202
203 // Returns the number of bytes received
204 int OlySocket::receive(char* buffer, int size) {
205   if (size <= 0 || buffer == NULL) {
206     return 0;
207   }
208
209   int bytes = recv(mSocketID, buffer, size, 0);
210   if (bytes < 0) {
211     logg->logError(__FILE__, __LINE__, "Socket receive error");
212     handleException();
213   } else if (bytes == 0) {
214     logg->logMessage("Socket disconnected");
215     return -1;
216   }
217   return bytes;
218 }
219
220 // Receive exactly size bytes of data. Note, this function will block until all bytes are received
221 int OlySocket::receiveNBytes(char* buffer, int size) {
222   int bytes = 0;
223   while (size > 0 && buffer != NULL) {
224     bytes = recv(mSocketID, buffer, size, 0);
225     if (bytes < 0) {
226       logg->logError(__FILE__, __LINE__, "Socket receive error");
227       handleException();
228     } else if (bytes == 0) {
229       logg->logMessage("Socket disconnected");
230       return -1;
231     }
232     buffer += bytes;
233     size -= bytes;
234   }
235   return bytes;
236 }
237
238 // Receive data until a carriage return, line feed, or null is encountered, or the buffer fills
239 int OlySocket::receiveString(char* buffer, int size) {
240   int bytes_received = 0;
241   bool found = false;
242
243   if (buffer == 0) {
244     return 0;
245   }
246
247   while (!found && bytes_received < size) {
248     // Receive a single character
249     int bytes = recv(mSocketID, &buffer[bytes_received], 1, 0);
250     if (bytes < 0) {
251       logg->logError(__FILE__, __LINE__, "Socket receive error");
252       handleException();
253     } else if (bytes == 0) {
254       logg->logMessage("Socket disconnected");
255       return -1;
256     }
257
258     // Replace carriage returns and line feeds with zero
259     if (buffer[bytes_received] == '\n' || buffer[bytes_received] == '\r' || buffer[bytes_received] == '\0') {
260       buffer[bytes_received] = '\0';
261       found = true;
262     }
263
264     bytes_received++;
265   }
266
267   return bytes_received;
268 }