/* # Copyright (c) 2009 - OpenSLX Project, Computer Center University of Freiburg # # This program is free software distributed under the GPL version 2. # See http://openslx.org/COPYING # # If you have any feedback please consult http://openslx.org/feedback and # send your suggestions, praise, or complaints to feedback@openslx.org # # General information about OpenSLX can be found at http://openslx.org/ # -------------------------------------------------------------------------- # pvsPrivInputSocket.h: # - Centralize knowledge of socket address and connection options # for pvsprivinputd - implementation # -------------------------------------------------------------------------- */ #include #include #include #include #include #include #include "pvsPrivInputSocket.h" using namespace std; #ifndef UNIX_PATH_MAX # define UNIX_PATH_MAX 108 /* according to unix(7) */ #endif static QSettings* pvsPrivInputSettings = 0; QString pvsPrivInputGetSettingsPath() { return "/etc/pvsprivinputd.conf"; } QSettings* pvsPrivInputGetSettings() { if(!pvsPrivInputSettings) { pvsPrivInputSettings = new QSettings(pvsPrivInputGetSettingsPath(), QSettings::IniFormat); } return pvsPrivInputSettings; } QSettings* pvsPrivInputReopenSettings() { if(pvsPrivInputSettings) { delete pvsPrivInputSettings; pvsPrivInputSettings = 0; } return pvsPrivInputGetSettings(); } QString pvsPrivInputGetSocketAddress() { return pvsPrivInputGetSettings()->value("socketpath", "/tmp/pvsprivinputd.sock").toString(); } bool pvsPrivInputEnableReceiveCredentials(int sock) { int passcred = 1; if(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &passcred, sizeof(passcred)) < 0) { return false; } else { return true; } } int pvsPrivInputMakeClientSocket() { int sock = socket(AF_UNIX, SOCK_DGRAM, 0); if(sock < 0) { qWarning("Could not create a socket: %s", strerror(errno)); return -1; } QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit(); struct sockaddr_un addr; memset(&addr, 0, sizeof(addr)); addr.sun_family = AF_UNIX; strncpy(addr.sun_path, socketPath.constData(), UNIX_PATH_MAX - 1); if(connect(sock, reinterpret_cast(&addr), sizeof(addr)) < 0) { qWarning("Could not connect to pvsprivinputd at %s: %s", socketPath.constData(), strerror(errno)); close(sock); return -1; } return sock; } int pvsPrivInputMakeServerSocket() { int sock = socket(AF_UNIX, SOCK_DGRAM, 0); if(sock < 0) { qCritical("Could not create a socket: %s", strerror(errno)); return -1; } // Bind to the address: QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit(); struct sockaddr_un addr; memset(&addr, 0, sizeof(addr)); addr.sun_family = AF_UNIX; strncpy(addr.sun_path, socketPath.constData(), UNIX_PATH_MAX - 1); if(bind(sock, reinterpret_cast(&addr), sizeof(addr)) < 0) { qCritical("Could not bind socket to %s", strerror(errno)); close(sock); return -1; } // Announce that credentials are requested: if(!pvsPrivInputEnableReceiveCredentials(sock)) { // We will not operate without credentials. qCritical("Could not request peer credentials: %s", strerror(errno)); close(sock); return -1; } #if 0 /* Only for SOCK_STREAM: */ // Listen for connections: if(listen(sock, 1) < 0) { qCritical("Could not listen for connections to %s: %s", socketPath.constData(), strerror(errno)); close(sock); return -1; } #endif return sock; } bool pvsPrivInputSendMessage(int sock, void* buf, size_t _len, int* err) { /* * Portability note: All UNIX-like systems can transmit credentials over UNIX * sockets, but only Linux does it automagically. */ long len = (long)_len; // send(2) does not split messages on a SOCK_DGRAM socket. int e = send(sock, buf, len, 0); if(e < 0) { qWarning("Failed to send message of length %d over socket %d: %s", (unsigned)len, e, strerror(errno)); if(err) *err = errno; return false; } else if(e < len) { qWarning("Failed to send a complete message of length %d over socket %d, only %d bytes were sent", (unsigned)len, sock, e); if(err) *err = errno; return false; } return true; } bool pvsPrivInputRecvMessage(int sock, void* buf, size_t& len, pid_t& pid, uid_t& uid, gid_t& gid, int* err) { struct iovec iov; struct msghdr msg; char ctlbuf[1024]; iov.iov_base = buf; iov.iov_len = len; msg.msg_name = 0; msg.msg_namelen = 0; msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = &ctlbuf; msg.msg_controllen = sizeof(ctlbuf); msg.msg_flags = 0; int bytes_read = recvmsg(sock, &msg, 0); if(bytes_read < 0) { qWarning("Could not read from socket: %s", strerror(errno)); if(err) *err = errno; return false; } pid = -1; uid = -1; gid = -1; struct cmsghdr* cmsg; for(cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) { if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS) { struct ucred* creds = reinterpret_cast(CMSG_DATA(cmsg)); pid = creds->pid; uid = creds->uid; gid = creds->gid; break; } else if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { // We need to close passed file descriptors. If we don't, we // have a denial-of-service vulnerability. int* fds = reinterpret_cast(CMSG_DATA(cmsg)); unsigned num_fds = cmsg->cmsg_len / sizeof(int); for(unsigned i = 0; i < num_fds; i++) { close(fds[i]); } } } if(pid == (pid_t)-1 || uid == (uid_t)-1 || gid == (gid_t)-1) { *err = 0; return false; } return true; }