00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #include "localsocketsite.hh"
00025 #include "accesscontrol.hh"
00026 #include "remotesocketsite.hh"
00027 #include "localmetaobject.hh"
00028 #include "message.hh"
00029 #include "timer.hh"
00030
00031 using namespace VOS;
00032
#include <stdio.h>
#include <fcntl.h>
#include <errno.h>
#include <math.h>
#ifndef WIN32
#include <unistd.h>
#endif

#include <deque>
#include <string>
#include <map>
00041
// Optional SOCKS proxy support: declare the SOCKS library's replacements for
// bind()/accept(). Rbind is used by init() when VOS_USE_SOCKS is set, and
// Raccept by acceptConnectionRequests() when usingSOCKS is true.
// NOTE(review): these prototypes use socklen_t, whose fallback typedef appears
// further down this file; on a platform lacking socklen_t this block would not
// compile — confirm the SOCKS headers provide it, or move the typedef up.
#ifdef SOCKS_SUPPORT
#undef __P
// Classic prototype-compatibility macro: expands to the parameter list on
// ANSI C / C++ compilers and to an empty "()" list otherwise.
#if defined (__STDC__) || defined (_AIX) \
|| (defined (__mips) && defined (_SYSTYPE_SVR4)) \
|| defined(WIN32) || defined(__cplusplus)
# define __P(protos) protos
#else
# define __P(protos) ()
#endif
extern "C" {
int Rbind __P((int, const struct sockaddr *, socklen_t));
int Raccept __P((int, struct sockaddr *, socklen_t *));
// NOTE(review): empty stub with C linkage; presumably satisfies a symbol the
// SOCKS library references at link time — confirm against the SOCKS library.
void checkmodule(const char*) {}
}
#endif

#ifdef WIN32
#include <windows.h>
#endif


// Fallback when the system headers do not define SOL_TCP: 6 is the TCP
// protocol number, used below as the setsockopt() level for TCP_NODELAY.
#ifndef SOL_TCP

#define SOL_TCP 6
#endif


// Pointer type for setsockopt()'s option-value argument: Winsock (ws2_32)
// declares it as char*, so cast to that when building against ws2_32.
#ifdef HAVE_LIBWS2_32
# define SETSOCKOPT_PARAM4_CAST char*
#else
# define SETSOCKOPT_PARAM4_CAST int*
#endif


// Fallback for platforms whose headers do not provide socklen_t.
#if !defined(HAVE_SOCKLEN_T) && !defined(__socklen_t_defined)

typedef int socklen_t;
#define __socklen_t_defined 1
#endif
00081
00082
00083
00084
00085
00086 void LocalSocketSite::init(const string& defaultHostname, unsigned short int port) throw (PortBindingError)
00087 {
00088 srand((unsigned int)time(0));
00089
00090 #ifdef WIN32 // @@@ should go somewhere else
00091 WSADATA WSAData;
00092
00093 WORD wVersionRequested = MAKEWORD(2, 0);
00094 WSAStartup(wVersionRequested, &WSAData);
00095 #endif
00096
00097 struct sockaddr_in sa;
00098 int set=1;
00099
00100 listenport=port;
00101 memset(&sa, 0, sizeof(struct sockaddr_in));
00102 sa.sin_port=htons(listenport);
00103 sa.sin_family=AF_INET;
00104 sa.sin_addr.s_addr = INADDR_ANY;
00105 if((listensocket = (int)socket(AF_INET, SOCK_STREAM, 0)) < 0)
00106 throw PortBindingError(strerror(errno));
00107 #ifndef WIN32
00108 setsockopt(listensocket, SOL_SOCKET, SO_REUSEADDR, (SETSOCKOPT_PARAM4_CAST)&set, sizeof(int));
00109 #endif
00110
00111 #ifdef SOCKS_SUPPORT
00112 if(getenv("VOS_USE_SOCKS")) {
00113 if(Rbind(listensocket, (struct sockaddr *)&sa, sizeof(struct sockaddr_in)) < 0)
00114 throw PortBindingError(strerror(errno));
00115 listen(listensocket, 5);
00116 usingSOCKS = true;
00117 } else {
00118 #endif
00119 if(bind(listensocket, (struct sockaddr *)&sa, sizeof(struct sockaddr_in)) < 0)
00120 throw PortBindingError(strerror(errno));
00121 listen(listensocket, 5);
00122 #ifdef SOCKS_SUPPORT
00123 usingSOCKS = false;
00124 }
00125 #endif
00126
00127 string host;
00128 char c[256];
00129 if(defaultHostname == "") {
00130 gethostname(c, sizeof(c));
00131 host=c;
00132 if(host.find('.') == string::npos) {
00133 LOG("LocalSocketSite", 1, "WARNING: Your host name '" << host << "' does not appear to include a domain name. You will probably not be able to use this application on the Internet! Use the VOS_HOSTNAME environment variable to specify a fully qualified domain name or IP address.");
00134 }
00135 } else host=defaultHostname;
00136
00137 snprintf(c, sizeof(c), "%s:%i", host.c_str(), port);
00138 addHostAlias(c);
00139 if(host != "localhost") {
00140 snprintf(c, sizeof(c), "localhost:%i", port);
00141 addHostAlias(c);
00142 }
00143 snprintf(c, sizeof(c), "%i", port);
00144 url.setHost(host);
00145 url.setPort(c);
00146 addSite(this);
00147 snprintf(c, sizeof(c), "%s:%i", host.c_str(), port);
00148 name = c;
00149
00150 iteratorsUsingPeerSites=0;
00151
00152 selectwait->tv_sec=0;
00153 selectwait->tv_usec=0;
00154
00155 if(LocalSite* s = Site::getDefaultPeer()) {
00156 s->release();
00157 } else {
00158 Site::setDefaultPeer(this);
00159 }
00160 }
00161
00162 LocalSocketSite::LocalSocketSite(SiteAccessControl* ac) throw (PortBindingError)
00163 : VobjectImplementation("", 0, true), LocalMetaObject("", 0, ac),
00164 LocalVobject("", 0, ac), MetaObject(0), alreadyDoingFlushIncoming(false)
00165 {
00166 bool foundPort = false;
00167 int c=VOS_DEFAULT_PORT;
00168 char hostname[256] = "";
00169 if(getenv("VOS_HOSTNAME")) strncpy(hostname, getenv("VOS_HOSTNAME"), sizeof(hostname));
00170 if (*hostname) {
00171 int i;
00172 for(i = 0; hostname[i]; i++) {
00173 if(hostname[i] == ':') {
00174 c = atoi(hostname+i+1);
00175 break;
00176 }
00177 }
00178 hostname[i] = 0;
00179 }
00180 for(; !foundPort && c < 0xFFFF; c++) {
00181 try {
00182 init(hostname, c);
00183 foundPort=true;
00184 } catch(PortBindingError) {
00185 foundPort=false;
00186 }
00187 }
00188
00189 if(c == 0xFFFF) throw PortBindingError("Cannot allocate port");
00190 }
00191
00192 LocalSocketSite::LocalSocketSite(const string& defaultHostname, unsigned short int port, SiteAccessControl* ac)
00193 throw (PortBindingError)
00194 : VobjectImplementation(defaultHostname, 0, true), LocalMetaObject(defaultHostname, 0, ac),
00195 LocalVobject(defaultHostname, 0, ac), MetaObject(0), alreadyDoingFlushIncoming(false)
00196 {
00197 init(defaultHostname, port);
00198 }
00199
00200 LocalSocketSite::LocalSocketSite(const string& defaultHostname, SiteAccessControl* ac) throw (PortBindingError)
00201 : VobjectImplementation(defaultHostname, 0, true), LocalMetaObject(defaultHostname, 0, ac),
00202 LocalVobject(defaultHostname, 0, ac), MetaObject(0), alreadyDoingFlushIncoming(false)
00203 {
00204 bool foundPort = false;
00205 int c;
00206 for(c=VOS_DEFAULT_PORT; !foundPort && c < 0xFFFF; c++) {
00207 try {
00208 init(defaultHostname, c);
00209 foundPort=true;
00210 } catch(PortBindingError) {
00211 foundPort=false;
00212 }
00213 }
00214
00215 if(c == 0xFFFF) throw PortBindingError("Cannot allocate port");
00216 }
00217
00218 LocalSocketSite::LocalSocketSite(unsigned short int port, SiteAccessControl* ac) throw (PortBindingError)
00219 : VobjectImplementation("", 0, true), LocalVobject("", 0, ac),
00220 LocalMetaObject("", 0, ac), MetaObject(0), alreadyDoingFlushIncoming(false)
00221 {
00222 char *hostname = getenv("VOS_HOSTNAME");
00223 if(hostname == NULL)
00224 hostname = "";
00225 init(hostname, port);
00226 }
00227
00228 LocalSocketSite::~LocalSocketSite()
00229 {
00230 #ifdef WIN32
00231 WSACleanup();
00232 #endif
00233 }
00234
00235 void LocalSocketSite::acceptConnectionRequests()
00236 {
00237 fd_set f;
00238 struct timeval tv = {0, 0};
00239 struct sockaddr_in sa;
00240
00241 socklen_t len=sizeof(sa);
00242
00243 FD_ZERO(&f);
00244 FD_SET(listensocket, &f);
00245 while(select(listensocket+1, &f, 0, 0, &tv) > 0
00246 && FD_ISSET(listensocket, &f)) {
00247 int newsock;
00248 #ifdef SOCKS_SUPPORT
00249 if(usingSOCKS) newsock=Raccept(listensocket, (sockaddr*)&sa, &len);
00250 else
00251 #endif
00252 newsock = (int)accept(listensocket, (sockaddr*)&sa, &len);
00253 int x = 1;
00254 setsockopt(newsock, SOL_TCP, TCP_NODELAY, (char*)&x, sizeof(int));
00255 #ifdef WIN32
00256 unsigned long one = 1;
00257 ioctlsocket(newsock, FIONBIO, &one);
00258 #else
00259 unsigned int one = fcntl(newsock, F_GETFL);
00260 one = one | O_NONBLOCK;
00261 fcntl(newsock, F_SETFL, one);
00262 #endif
00263 LOG("localsite", 3, "got new connection, fd " << newsock);
00264 if(newsock > -1) {
00265 allOpenSockets.insert(newsock);
00266 vRef<RemoteSocketSite> rs = new RemoteSocketSite(newsock, &sa);
00267
00268 if(getenv("VOS_HOSTNAME") == 0) {
00269 unsigned char ipaddr[4];
00270 string host = detectHostname(rs->getReadingFD(), ipaddr);
00271 URL u(getURL());
00272 u.setHost(host);
00273 addHostAlias(u.getHostAndPort());
00274 if(!(ipaddr[0] == 10 || ipaddr[0] == 127 || (ipaddr[0] == 192 && ipaddr[1] == 168))) {
00275 setPrimaryHostname(host);
00276 }
00277 }
00278 try {
00279 LOG("refcount", 5, "count on new rs is " << rs->getCount());
00280
00281
00282
00283 LOG("localsocketsite", 4, "Beginning site peering of inbound connection");
00284 doSitePeering(this, &rs, false, true);
00285 LOG("localsocketsite", 4, "Done site peering of inbound connection");
00286 LOG("refcount", 5, "count on rs post-peering is " << rs->getCount());
00287
00288
00289
00290
00291
00292
00293
00294 vRef<Message> m = new Message();
00295 m->setType("message");
00296 m->setMethod("core:get-types");
00297 m->setTo(rs->getURL().getString());
00298 m->setFrom(getURL().getString());
00299 rs->sendMessage(&m);
00300 } catch(RemoteError x) {
00301 LOG("localsite", 3, "Error peering with remote site: " << x.what());
00302 }
00303 LOG("refcount", 5, "count on rs getTypes n stuff is " << rs->getCount());
00304 } else {
00305 LOG("localsite", 3, "could not accept connection request: " << strerror(errno));
00306 break;
00307 }
00308 FD_ZERO(&f);
00309 FD_SET(listensocket, &f);
00310 }
00311 }
00312
00313 int LocalSocketSite::getFDset(fd_set* readset, fd_set* writeset, fd_set* errorset)
00314 {
00315 FD_SET(listensocket, readset);
00316 int max=listensocket;
00317
00318 for(set<int>::iterator i = allOpenSockets.begin(); i != allOpenSockets.end(); i++) {
00319 if(*i > -1) {
00320 FD_SET(*i, readset);
00321 if(*i > max) max=*i;
00322 }
00323 }
00324
00325
00326 for(map<int, pair<int, int> >::iterator i = extra_fds.begin(); i != extra_fds.end(); i++) {
00327 if((*i).first > max) max = (*i).first;
00328 if((*i).second.first & SELECTREAD) FD_SET((*i).first, readset);
00329 if((*i).second.first & SELECTWRITE) FD_SET((*i).first, writeset);
00330 if((*i).second.first & SELECTERROR) FD_SET((*i).first, errorset);
00331 }
00332 return max;
00333 }
00334
00335 struct timeval* LocalSocketSite::calculateTimeout()
00336 {
00337 struct timeval *tv;
00338
00339 if(selectwait) {
00340 tv = new struct timeval();
00341 *tv = *selectwait;
00342 } else tv=0;
00343
00344 if(!siteMessageQueue.empty() || !callbacks.empty()) {
00345 LOG("localsocketsite", 5, "smq " << siteMessageQueue.empty() << " cb " << callbacks.empty());
00346
00347 double now = getTimer();
00348
00349 double nextEvent;
00350 if(siteMessageQueue.empty()) {
00351 nextEvent = (*callbacks.begin()).first;
00352 } else if(callbacks.empty()) {
00353 nextEvent = (*siteMessageQueue.begin()).first;
00354 } else {
00355 if((*siteMessageQueue.begin()).first < (*callbacks.begin()).first)
00356 nextEvent = (*siteMessageQueue.begin()).first;
00357 else
00358 nextEvent = (*callbacks.begin()).first;
00359 }
00360
00361
00362
00363
00364 nextEvent -= now;
00365
00366 if(nextEvent <= 0) {
00367
00368
00369 if(! tv) tv = new struct timeval();
00370 tv->tv_sec = 0;
00371 tv->tv_usec = 0;
00372 } else {
00373 if(tv) {
00374
00375
00376
00377
00378 double maxWait = tv->tv_sec + (tv->tv_usec / 1000000.0);
00379 if(nextEvent < maxWait) {
00380 tv->tv_sec = (long)floor(nextEvent);
00381 tv->tv_usec = (long)((nextEvent - floor(nextEvent)) * 1000000.0);
00382 }
00383 } else {
00384
00385 tv = new struct timeval();
00386 tv->tv_sec = (long)floor(nextEvent);
00387 tv->tv_usec = (long)((nextEvent - floor(nextEvent)) * 1000000.0);
00388 }
00389 }
00390 }
00391
00392 return tv;
00393 }
00394
00395 void LocalSocketSite::flushIncomingBuffers()
00396 {
00397 bool more;
00398
00399
00400 for(list< pair<string, RemoteSite*> >::iterator i = needSpoofIDreply.begin();
00401 i != needSpoofIDreply.end(); i++)
00402 {
00403 string spid = getAntiSpoofIDMapping((*i).first);
00404 if(spid != "") {
00405 vRef<Message> reply = new Message();
00406 reply->setType("message");
00407 reply->setFrom("");
00408 reply->setTo("");
00409 reply->setMethod("core:anti-spoof-reply");
00410 reply->insertField(-1, "poke", spid);
00411 (*i).second->sendMessage(&reply);
00412
00413 (*i).second->release();
00414
00415 list< pair<string, RemoteSite*> >::iterator n = i;
00416 i++;
00417 needSpoofIDreply.erase(n);
00418 }
00419 }
00420
00421
00422
00423 do {
00424 more=false;
00425
00426 struct timeval *tv = 0;
00427 fd_set readset, writeset, errorset;
00428
00429 do {
00430 FD_ZERO(&readset);
00431 FD_ZERO(&writeset);
00432 FD_ZERO(&errorset);
00433
00434 int max = getFDset(&readset, &writeset, &errorset);
00435
00436 tv = calculateTimeout();
00437
00438 errno = 0;
00439 select(max+1, &readset, &writeset, &errorset, tv);
00440
00441 if(tv) delete tv;
00442 } while(errno == EINTR);
00443
00444 for(map<int, pair<int, int> >::iterator i = extra_fds.begin(); i != extra_fds.end(); i++) {
00445 (*i).second.second = 0;
00446 if(FD_ISSET((*i).first, &readset)) (*i).second.second |= SELECTREAD;
00447 if(FD_ISSET((*i).first, &writeset)) (*i).second.second |= SELECTWRITE;
00448 if(FD_ISSET((*i).first, &errorset)) (*i).second.second |= SELECTERROR;
00449 }
00450
00451 if(FD_ISSET(listensocket, &readset)) acceptConnectionRequests();
00452
00453 iteratorsUsingPeerSites++;
00454 for(map<RemoteSite*, SiteTableEntry*>::iterator i = peerSites.begin(); i != peerSites.end(); i++) {
00455 if(!((*i).first && (*i).second)) continue;
00456 if(RemoteSocketSite* rss = dynamic_cast<RemoteSocketSite*>((*i).first)) {
00457 LOG("localsocketsite", 5, "waaagh, rss->getReadingFD is " << rss->getReadingFD() << " and "
00458 << (FD_ISSET(rss->getReadingFD(), &readset) != 0));
00459
00460 if(rss->getReadingFD() > -1 && FD_ISSET(rss->getReadingFD(), &readset))
00461 rss->flushIncomingBuffers();
00462
00463 if(rss->getWritingFD() > -1 && FD_ISSET(rss->getWritingFD(), &writeset))
00464 rss->flushOutgoingBuffers();
00465 } else (*i).first->flushIncomingBuffers();
00466 }
00467 if(--iteratorsUsingPeerSites == 0) {
00468 boost::mutex::scoped_lock lock(peerSitesBuffer_mutex);
00469
00470 while(peerSitesBuffer_remove.size() > 0) {
00471 removeRemotePeer(peerSitesBuffer_remove.front());
00472 peerSitesBuffer_remove.pop_front();
00473 }
00474 deque<RemoteSite*>::iterator i = peerSitesBuffer_rs.begin();
00475 deque<struct SiteTableEntry*>::iterator n = peerSitesBuffer_st.begin();
00476 while(i != peerSitesBuffer_rs.end()) {
00477 peerSites[*i] = *n;
00478 i++;
00479 n++;
00480 }
00481 peerSitesBuffer_rs.clear();
00482 peerSitesBuffer_st.clear();
00483 }
00484 } while(more);
00485
00486
00487
00488 runSchedule();
00489 }
00490
00491 const int LocalSocketSite::SELECTREAD = 0x1;
00492 const int LocalSocketSite::SELECTWRITE = 0x2;
00493 const int LocalSocketSite::SELECTERROR = 0x4;
00494
00495 void LocalSocketSite::addFDtoSelect(int fd, int which)
00496 {
00497 extra_fds[fd] = pair<int, int>(which, 0);
00498 }
00499
00500 void LocalSocketSite::removeFDfromSelect(int fd, int which)
00501 {
00502 extra_fds[fd].first &= ~which;
00503 if(extra_fds[fd].first == 0) extra_fds.erase(fd);
00504 }
00505
00506 int LocalSocketSite::lastSelectResultForFD(int fd)
00507 {
00508 return extra_fds[fd].second;
00509 }
00510
00511 SSL_CTX* LocalSocketSite::getSSLContext(const string& protocol)
00512 {
00513 if(protocol == "TLSv1") {
00514 return TLSv1context;
00515 } else if(protocol == "SSLv3") {
00516 return SSLv3context;
00517 } else return 0;
00518 }
00519
/*
 * Send a message via LocalSite::sendMessage(), then answer it if its method is
 * "core:protocol-switch".
 *
 * The reply is initialised by initReply(..., "core:protocol-switch-reply")
 * and sent through the object found for m->getFrom():
 *  - SSL_SUPPORT builds: only when m's source site is a RemoteSocketSite whose
 *    local peer is this site. The "protocol" field is looked up with
 *    getSSLContext(). If a context exists, "result" is "ready" and, after the
 *    reply is sent, the connection is switched with rs->switchProtocol(ctx).
 *    Otherwise "result" names the unsupported protocol. Message::NoSuchFieldError,
 *    bad_cast (source is not a RemoteSocketSite) and NoSuchObjectError are
 *    caught and silently ignored.
 *  - Other builds: "result" is "this site does not support secure protocols".
 *
 * NOTE(review): pREF/rREF are project macros that take a statement body as an
 * argument; presumably they hold/release a reference around it — check their
 * definitions before restructuring this code.
 */
void LocalSocketSite::sendMessage(Message* m)
{
    LocalSite::sendMessage(m);

    if(m->getMethod() == "core:protocol-switch") {
#ifdef SSL_SUPPORT
        try {
            pREF(RemoteSocketSite*, rs, meta_cast<RemoteSocketSite*>(m->getSourceSite()),
                 if(rs->getLocalPeer() == this) {
                     SSL_CTX* ctx = getSSLContext(m->getField("protocol").value);
                     pREF(Message*, reply, new Message(),
                          rREF(Vobject&, from, findObject(m->getFrom()),
                               initReply(this, reply, m, "core:protocol-switch-reply");
                               if(ctx) {
                                   reply->insertField(-1, "result", "ready");
                                   from.sendMessage(reply);
                                   rs->switchProtocol(ctx);
                               } else {
                                   reply->insertField(-1, "result", "unsupported protocol "
                                                      + m->getField("protocol").value + string("; supports TLSv1, SSLv3"));
                                   from.sendMessage(reply);
                               }
                              );
                         );
                 }
                );
        } catch(Message::NoSuchFieldError) {
        } catch(bad_cast) {
        } catch(NoSuchObjectError) {
        }
#else
        // Built without SSL: always refuse the switch.
        pREF(Message*, reply, new Message(),
             rREF(Vobject&, from, findObject(m->getFrom()),
                  initReply(this, reply, m, "core:protocol-switch-reply");
                  reply->insertField(-1, "result", "this site does not support secure protocols");
                  from.sendMessage(reply);
                 );
            );
#endif
    }
}
00561
00562 int LocalSocketSite::pemPasswordCallback(char *buf, int size, int rwflag, void *password)
00563 {
00564 string* pw = (string*)password;
00565 strncpy(buf, pw->c_str(), (pw->size() < size ? pw->size() : size));
00566 buf[size - 1] = '\0';
00567 return (int)strlen(buf);
00568 }
00569
00570 int LocalSocketSite::sslVerifyCallback(int preverify_ok, X509_STORE_CTX *ctx)
00571 {
00572 return 1;
00573 }
00574
00575 void LocalSocketSite::setupSSL(const string& certificateFile, const string& privateKeyFile, const string& password)
00576 {
00577 #ifdef SSL_SUPPORT
00578 static bool needInitSSL = true;
00579
00580 if(needInitSSL) {
00581 SSL_load_error_strings();
00582 SSL_library_init();
00583 srand(time(NULL));
00584 int rnd = rand();
00585 RAND_add(&rnd, 4, 0);
00586
00587 needInitSSL = false;
00588 }
00589
00590 static string pw = password;
00591
00592 TLSv1context = SSL_CTX_new(TLSv1_method());
00593 if(! TLSv1context) {
00594 throw SSLError("Could not set up a TLSv1 context");
00595 }
00596
00597 SSLv3context = SSL_CTX_new(SSLv3_method());
00598 if(! SSLv3context) {
00599 throw SSLError("Could not set up a SSLv3 context");
00600 }
00601
00602 SSL_CTX_set_default_passwd_cb(TLSv1context, pemPasswordCallback);
00603 SSL_CTX_set_default_passwd_cb(SSLv3context, pemPasswordCallback);
00604
00605 SSL_CTX_set_default_passwd_cb_userdata(TLSv1context, &pw);
00606 SSL_CTX_set_default_passwd_cb_userdata(SSLv3context, &pw);
00607
00608 int ret = SSL_CTX_use_RSAPrivateKey_file (TLSv1context, privateKeyFile.c_str(), SSL_FILETYPE_PEM);
00609 if(!ret) {
00610 throw SSLError("Bad private key (cannot read or decrypt file)");
00611 }
00612 SSL_CTX_use_RSAPrivateKey_file (SSLv3context, privateKeyFile.c_str(), SSL_FILETYPE_PEM);
00613
00614 ret = SSL_CTX_use_certificate_file (TLSv1context, certificateFile.c_str(), SSL_FILETYPE_PEM);
00615 if(!ret) {
00616 throw SSLError("Bad certificate key (cannot read file)");
00617 }
00618 SSL_CTX_use_certificate_file (SSLv3context, certificateFile.c_str(), SSL_FILETYPE_PEM);
00619
00620 SSL_CTX_set_verify(TLSv1context, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, sslVerifyCallback);
00621 SSL_CTX_set_verify(SSLv3context, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, sslVerifyCallback);
00622 #else
00623 throw RemoteError("libvos was not compiled with SSL/TLS support");
00624 #endif
00625 }