summaryrefslogtreecommitdiffstats
path: root/src/input/pvsPrivInputSocket.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/input/pvsPrivInputSocket.cpp')
-rw-r--r--src/input/pvsPrivInputSocket.cpp196
1 files changed, 196 insertions, 0 deletions
diff --git a/src/input/pvsPrivInputSocket.cpp b/src/input/pvsPrivInputSocket.cpp
new file mode 100644
index 0000000..2428582
--- /dev/null
+++ b/src/input/pvsPrivInputSocket.cpp
@@ -0,0 +1,196 @@
+/*
+ # Copyright (c) 2009 - OpenSLX Project, Computer Center University of Freiburg
+ #
+ # This program is free software distributed under the GPL version 2.
+ # See http://openslx.org/COPYING
+ #
+ # If you have any feedback please consult http://openslx.org/feedback and
+ # send your suggestions, praise, or complaints to feedback@openslx.org
+ #
+ # General information about OpenSLX can be found at http://openslx.org/
+ # --------------------------------------------------------------------------
+ # pvsPrivInputSocket.h:
+ # - Centralize knowledge of socket address and connection options
+ # for pvsprivinputd - implementation
+ # --------------------------------------------------------------------------
+ */
+
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <cerrno>
+#include <QtDebug>
+#include <QSettings>
+#include "pvsPrivInputSocket.h"
+
+using namespace std;
+
+#ifndef UNIX_PATH_MAX
+# define UNIX_PATH_MAX 108 /* according to unix(7) */
+#endif
+
+QString pvsPrivInputGetSocketAddress()
+{
+ QSettings settings(QSettings::NativeFormat, QSettings::SystemScope, "openslx", "pvsprivinputd");
+ return settings.value("socketpath", "/tmp/pvsprivinputd.sock").toString();
+}
+
+int pvsPrivInputMakeClientSocket()
+{
+ int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
+ if(sock < 0)
+ {
+ qWarning("Could not create a socket: %s", strerror(errno));
+ return -1;
+ }
+
+ QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit();
+ struct sockaddr_un addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sun_family = AF_UNIX;
+ strncpy(addr.sun_path, socketPath.constData(), UNIX_PATH_MAX - 1);
+ if(connect(sock, reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)) < 0)
+ {
+ qWarning("Could not connect to pvsprivinputd at %s: %s", socketPath.constData(), strerror(errno));
+ close(sock);
+ return -1;
+ }
+
+ return sock;
+}
+
+int pvsPrivInputMakeServerSocket()
+{
+ int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
+ if(sock < 0)
+ {
+ qCritical("Could not create a socket: %s", strerror(errno));
+ return -1;
+ }
+
+ // Bind to the address:
+ QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit();
+ struct sockaddr_un addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sun_family = AF_UNIX;
+ strncpy(addr.sun_path, socketPath.constData(), UNIX_PATH_MAX - 1);
+ if(bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0)
+ {
+ qCritical("Could not bind socket to %s", strerror(errno));
+ close(sock);
+ return -1;
+ }
+
+ // Announce that credentials are requested:
+ int passcred = 1;
+ if(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &passcred, sizeof(passcred)) < 0)
+ {
+ // We will not operate without credentials.
+ qCritical("Could not request peer credentials: %s", strerror(errno));
+ close(sock);
+ return -1;
+ }
+
+#if 0 /* Only for SOCK_STREAM: */
+ // Listen for connections:
+ if(listen(sock, 1) < 0)
+ {
+ qCritical("Could not listen for connections to %s: %s", socketPath.constData(), strerror(errno));
+ close(sock);
+ return -1;
+ }
+#endif
+
+ return sock;
+}
+
+bool pvsPrivInputSendMessage(int sock, void* buf, size_t _len, int* err)
+{
+ /*
+ * Portability note: All UNIX-like systems can transmit credentials over UNIX
+ * sockets, but only Linux does it automagically.
+ */
+
+ long len = (long)_len;
+
+ // send(2) does not split messages on a SOCK_DGRAM socket.
+ int e = send(sock, buf, len, 0);
+ if(e < 0)
+ {
+ qWarning("Failed to send message of length %d over socket %d: %s", (unsigned)len, e, strerror(errno));
+ if(err)
+ *err = errno;
+ return false;
+ }
+ else if(e < len)
+ {
+ qWarning("Failed to send a complete message of length %d over socket %d, only %d bytes were sent", (unsigned)len, sock, e);
+ if(err)
+ *err = errno;
+ return false;
+ }
+
+ return true;
+}
+
+bool pvsPrivInputRecvMessage(int sock, void* buf, size_t& len,
+ pid_t& pid, uid_t& uid, gid_t& gid, int* err)
+{
+ struct iovec iov;
+ struct msghdr msg;
+ char ctlbuf[1024];
+ iov.iov_base = buf;
+ iov.iov_len = len;
+ msg.msg_name = 0;
+ msg.msg_namelen = 0;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = &ctlbuf;
+ msg.msg_controllen = sizeof(ctlbuf);
+ msg.msg_flags = 0;
+ int bytes_read = recvmsg(sock, &msg, 0);
+ if(bytes_read < 0)
+ {
+ qWarning("Could not read from socket: %s", strerror(errno));
+ if(err)
+ *err = errno;
+ return false;
+ }
+
+ pid = -1;
+ uid = -1;
+ gid = -1;
+
+ struct cmsghdr* cmsg;
+ for(cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg))
+ {
+ if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS)
+ {
+ struct ucred* creds = reinterpret_cast<struct ucred*>(CMSG_DATA(cmsg));
+ pid = creds->pid;
+ uid = creds->uid;
+ gid = creds->gid;
+ break;
+ }
+ else if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS)
+ {
+ // We need to close passed file descriptors. If we don't, we
+ // have a denial-of-service vulnerability.
+ int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
+ unsigned num_fds = cmsg->cmsg_len / sizeof(int);
+ for(unsigned i = 0; i < num_fds; i++)
+ {
+ close(fds[i]);
+ }
+ }
+ }
+
+ if(pid == (pid_t)-1 || uid == (uid_t)-1 || gid == (gid_t)-1)
+ {
+ *err = 0;
+ return false;
+ }
+
+ return true;
+}