diff options
Diffstat (limited to 'src/input/pvsPrivInputSocket.cpp')
| -rw-r--r-- | src/input/pvsPrivInputSocket.cpp | 196 |
1 files changed, 196 insertions, 0 deletions
diff --git a/src/input/pvsPrivInputSocket.cpp b/src/input/pvsPrivInputSocket.cpp new file mode 100644 index 0000000..2428582 --- /dev/null +++ b/src/input/pvsPrivInputSocket.cpp @@ -0,0 +1,196 @@ +/* + # Copyright (c) 2009 - OpenSLX Project, Computer Center University of Freiburg + # + # This program is free software distributed under the GPL version 2. + # See http://openslx.org/COPYING + # + # If you have any feedback please consult http://openslx.org/feedback and + # send your suggestions, praise, or complaints to feedback@openslx.org + # + # General information about OpenSLX can be found at http://openslx.org/ + # -------------------------------------------------------------------------- + # pvsPrivInputSocket.h: + # - Centralize knowledge of socket address and connection options + # for pvsprivinputd - implementation + # -------------------------------------------------------------------------- + */ + +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> +#include <cerrno> +#include <QtDebug> +#include <QSettings> +#include "pvsPrivInputSocket.h" + +using namespace std; + +#ifndef UNIX_PATH_MAX +# define UNIX_PATH_MAX 108 /* according to unix(7) */ +#endif + +QString pvsPrivInputGetSocketAddress() +{ + QSettings settings(QSettings::NativeFormat, QSettings::SystemScope, "openslx", "pvsprivinputd"); + return settings.value("socketpath", "/tmp/pvsprivinputd.sock").toString(); +} + +int pvsPrivInputMakeClientSocket() +{ + int sock = socket(AF_UNIX, SOCK_DGRAM, 0); + if(sock < 0) + { + qWarning("Could not create a socket: %s", strerror(errno)); + return -1; + } + + QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit(); + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, socketPath.constData(), UNIX_PATH_MAX - 1); + if(connect(sock, reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)) < 0) + { + qWarning("Could not connect to pvsprivinputd at %s: %s", socketPath.constData(), strerror(errno)); + close(sock); + return -1; + } + + return sock; +} + +int pvsPrivInputMakeServerSocket() +{ + int sock = socket(AF_UNIX, SOCK_DGRAM, 0); + if(sock < 0) + { + qCritical("Could not create a socket: %s", strerror(errno)); + return -1; + } + + // Bind to the address: + QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit(); + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, socketPath.constData(), UNIX_PATH_MAX - 1); + if(bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) + { + qCritical("Could not bind socket to %s", strerror(errno)); + close(sock); + return -1; + } + + // Announce that credentials are requested: + int passcred = 1; + if(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &passcred, sizeof(passcred)) < 0) + { + // We will not operate without credentials. + qCritical("Could not request peer credentials: %s", strerror(errno)); + close(sock); + return -1; + } + +#if 0 /* Only for SOCK_STREAM: */ + // Listen for connections: + if(listen(sock, 1) < 0) + { + qCritical("Could not listen for connections to %s: %s", socketPath.constData(), strerror(errno)); + close(sock); + return -1; + } +#endif + + return sock; +} + +bool pvsPrivInputSendMessage(int sock, void* buf, size_t _len, int* err) +{ + /* + * Portability note: All UNIX-like systems can transmit credentials over UNIX + * sockets, but only Linux does it automagically. + */ + + long len = (long)_len; + + // send(2) does not split messages on a SOCK_DGRAM socket. + int e = send(sock, buf, len, 0); + if(e < 0) + { + qWarning("Failed to send message of length %d over socket %d: %s", (unsigned)len, e, strerror(errno)); + if(err) + *err = errno; + return false; + } + else if(e < len) + { + qWarning("Failed to send a complete message of length %d over socket %d, only %d bytes were sent", (unsigned)len, sock, e); + if(err) + *err = errno; + return false; + } + + return true; +} + +bool pvsPrivInputRecvMessage(int sock, void* buf, size_t& len, + pid_t& pid, uid_t& uid, gid_t& gid, int* err) +{ + struct iovec iov; + struct msghdr msg; + char ctlbuf[1024]; + iov.iov_base = buf; + iov.iov_len = len; + msg.msg_name = 0; + msg.msg_namelen = 0; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = &ctlbuf; + msg.msg_controllen = sizeof(ctlbuf); + msg.msg_flags = 0; + int bytes_read = recvmsg(sock, &msg, 0); + if(bytes_read < 0) + { + qWarning("Could not read from socket: %s", strerror(errno)); + if(err) + *err = errno; + return false; + } + + pid = -1; + uid = -1; + gid = -1; + + struct cmsghdr* cmsg; + for(cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) + { + if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS) + { + struct ucred* creds = reinterpret_cast<struct ucred*>(CMSG_DATA(cmsg)); + pid = creds->pid; + uid = creds->uid; + gid = creds->gid; + break; + } + else if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) + { + // We need to close passed file descriptors. If we don't, we + // have a denial-of-service vulnerability. + int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg)); + unsigned num_fds = cmsg->cmsg_len / sizeof(int); + for(unsigned i = 0; i < num_fds; i++) + { + close(fds[i]); + } + } + } + + if(pid == (pid_t)-1 || uid == (uid_t)-1 || gid == (gid_t)-1) + { + *err = 0; + return false; + } + + return true; +} |
