summaryrefslogblamecommitdiffstats
path: root/src/input/pvsPrivInputSocket.cpp
blob: df5dff5577ccf5e619d1806503d6f80429f4cb44 (plain) (tree)






























                                                                               

























                                                                                                          

                                      


                                                                                                    










                                                                                      

















































                                                                                                                  
                                                       












































































































                                                                                                                                           
/*
 # Copyright (c) 2009 - OpenSLX Project, Computer Center University of Freiburg
 #
 # This program is free software distributed under the GPL version 2.
 # See http://openslx.org/COPYING
 #
 # If you have any feedback please consult http://openslx.org/feedback and
 # send your suggestions, praise, or complaints to feedback@openslx.org
 #
 # General information about OpenSLX can be found at http://openslx.org/
 # --------------------------------------------------------------------------
 # pvsPrivInputSocket.cpp:
 #  - Centralize knowledge of socket address and connection options
 #    for pvsprivinputd - implementation
 # --------------------------------------------------------------------------
 */

#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <QtDebug>
#include <QSettings>
#include "pvsPrivInputSocket.h"

// Makes names from namespace std visible unqualified in this translation unit.
using namespace std;

// Fallback definition for when the system headers do not provide UNIX_PATH_MAX.
#ifndef UNIX_PATH_MAX
# define UNIX_PATH_MAX 108 /* according to unix(7) */
#endif

// Settings object shared by the functions in this file. It stays null until
// pvsPrivInputGetSettings() allocates it; pvsPrivInputReopenSettings()
// deletes and re-creates it.
static QSettings* pvsPrivInputSettings = 0;

/**
 * Location of the pvsprivinputd configuration file.
 *
 * @return the fixed path of the configuration file.
 */
QString pvsPrivInputGetSettingsPath()
{
	QString settingsPath("/etc/pvsprivinputd.conf");
	return settingsPath;
}

/**
 * Access the shared settings object, creating it on first use.
 *
 * The object is opened in INI format on the file named by
 * pvsPrivInputGetSettingsPath() and is reused by subsequent calls.
 *
 * @return pointer to the shared QSettings instance.
 */
QSettings* pvsPrivInputGetSettings()
{
	// Already created by an earlier call: hand out the same instance.
	if(pvsPrivInputSettings)
		return pvsPrivInputSettings;

	pvsPrivInputSettings = new QSettings(pvsPrivInputGetSettingsPath(), QSettings::IniFormat);
	return pvsPrivInputSettings;
}

/**
 * Discard the shared settings object (if any) and create a fresh one.
 *
 * Pointers returned by earlier calls to pvsPrivInputGetSettings() refer to
 * the deleted object afterwards.
 *
 * @return pointer to the newly created QSettings instance.
 */
QSettings* pvsPrivInputReopenSettings()
{
	// Deleting a null pointer is a no-op, so no explicit check is required.
	delete pvsPrivInputSettings;
	pvsPrivInputSettings = 0;

	return pvsPrivInputGetSettings();
}

/**
 * Look up the filesystem path of the pvsprivinputd socket.
 *
 * Reads the "socketpath" key from the shared settings and falls back to
 * "/tmp/pvsprivinputd.sock" when the key is not set.
 *
 * @return the socket path as a QString.
 */
QString pvsPrivInputGetSocketAddress()
{
	QSettings* const settings = pvsPrivInputGetSettings();
	return settings->value("socketpath", "/tmp/pvsprivinputd.sock").toString();
}

/**
 * Turn on the SO_PASSCRED option for a socket.
 *
 * @param sock socket descriptor to configure.
 * @return true if setsockopt() succeeded, false otherwise (errno is left as
 *         set by setsockopt()).
 */
bool pvsPrivInputEnableReceiveCredentials(int sock)
{
	int enable = 1;
	const int rc = setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &enable, sizeof(enable));
	return rc >= 0;
}

/**
 * Create a datagram UNIX socket connected to pvsprivinputd.
 *
 * The peer address is the path returned by pvsPrivInputGetSocketAddress().
 * A path that does not fit into sockaddr_un::sun_path (including the
 * terminating NUL) is rejected instead of being silently truncated, because
 * a truncated path names a different socket than the one configured.
 *
 * @return the connected socket descriptor, or -1 on failure. Failures are
 *         reported via qWarning().
 */
int pvsPrivInputMakeClientSocket()
{
	QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit();
	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;

	// Validate the path before acquiring any resource.
	const size_t pathLen = strlen(socketPath.constData());
	if(pathLen >= sizeof(addr.sun_path))
	{
		qWarning("Socket path %s is too long: at most %lu bytes are supported",
				socketPath.constData(),
				static_cast<unsigned long>(sizeof(addr.sun_path) - 1));
		return -1;
	}
	// addr was zeroed above, so sun_path stays NUL-terminated.
	memcpy(addr.sun_path, socketPath.constData(), pathLen);

	int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
	if(sock < 0)
	{
		qWarning("Could not create a socket: %s", strerror(errno));
		return -1;
	}

	if(connect(sock, reinterpret_cast<struct sockaddr*>(&addr),
				sizeof(addr)) < 0)
	{
		qWarning("Could not connect to pvsprivinputd at %s: %s", socketPath.constData(), strerror(errno));
		close(sock);
		return -1;
	}

	return sock;
}

/**
 * Create the datagram UNIX socket on which pvsprivinputd receives messages.
 *
 * The socket is bound to the path returned by pvsPrivInputGetSocketAddress()
 * and has SO_PASSCRED enabled via pvsPrivInputEnableReceiveCredentials().
 * A path that does not fit into sockaddr_un::sun_path (including the
 * terminating NUL) is rejected instead of being silently truncated.
 *
 * @return the bound socket descriptor, or -1 on failure. Failures are
 *         reported via qCritical().
 */
int pvsPrivInputMakeServerSocket()
{
	QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit();
	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;

	// Validate the path before acquiring any resource.
	const size_t pathLen = strlen(socketPath.constData());
	if(pathLen >= sizeof(addr.sun_path))
	{
		qCritical("Socket path %s is too long: at most %lu bytes are supported",
				socketPath.constData(),
				static_cast<unsigned long>(sizeof(addr.sun_path) - 1));
		return -1;
	}
	// addr was zeroed above, so sun_path stays NUL-terminated.
	memcpy(addr.sun_path, socketPath.constData(), pathLen);

	int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
	if(sock < 0)
	{
		qCritical("Could not create a socket: %s", strerror(errno));
		return -1;
	}

	// Bind to the address:
	if(bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0)
	{
		// Report both the path and the reason for the failure.
		qCritical("Could not bind socket to %s: %s", socketPath.constData(), strerror(errno));
		close(sock);
		return -1;
	}

	// Announce that credentials are requested:
	if(!pvsPrivInputEnableReceiveCredentials(sock))
	{
		// We will not operate without credentials.
		qCritical("Could not request peer credentials: %s", strerror(errno));
		close(sock);
		return -1;
	}

#if 0 /* Only for SOCK_STREAM: */
	// Listen for connections:
	if(listen(sock, 1) < 0)
	{
		qCritical("Could not listen for connections to %s: %s", socketPath.constData(), strerror(errno));
		close(sock);
		return -1;
	}
#endif

	return sock;
}

/**
 * Send one message over a socket with a single send(2) call.
 *
 * @param sock socket descriptor to send on.
 * @param buf  start of the message.
 * @param _len number of bytes to send.
 * @param err  optional; if non-null it receives the error code on failure:
 *             the errno value of send(2), or EIO if fewer than _len bytes
 *             were sent (send(2) does not set errno in that case).
 * @return true if the complete message was sent, false otherwise. Failures
 *         are reported via qWarning().
 */
bool pvsPrivInputSendMessage(int sock, void* buf, size_t _len, int* err)
{
	/*
	 * Portability note: All UNIX-like systems can transmit credentials over UNIX
	 * sockets, but only Linux does it automagically.
	 */

	// send(2) does not split messages on a SOCK_DGRAM socket.
	const ssize_t sent = send(sock, buf, _len, 0);
	if(sent < 0)
	{
		// Save errno first: the logging call may overwrite it.
		const int savedErrno = errno;
		qWarning("Failed to send message of length %lu over socket %d: %s",
				static_cast<unsigned long>(_len), sock, strerror(savedErrno));
		if(err)
			*err = savedErrno;
		return false;
	}
	else if(static_cast<size_t>(sent) < _len)
	{
		qWarning("Failed to send a complete message of length %lu over socket %d, only %ld bytes were sent",
				static_cast<unsigned long>(_len), sock, static_cast<long>(sent));
		// errno is not meaningful after a successful (short) send.
		if(err)
			*err = EIO;
		return false;
	}

	return true;
}

/**
 * Receive one message together with the sender's credentials.
 *
 * Performs a single recvmsg(2) call and scans the ancillary data for an
 * SCM_CREDENTIALS message. File descriptors passed via SCM_RIGHTS are closed
 * immediately so that a sender cannot exhaust the descriptor table.
 *
 * @param sock socket descriptor to read from.
 * @param buf  buffer that receives the message payload.
 * @param len  capacity of buf in bytes.
 *             NOTE(review): this reference is never updated with the number
 *             of bytes received - confirm against callers whether that is
 *             intended.
 * @param pid  receives the sender's process id, or -1 if none was delivered.
 * @param uid  receives the sender's user id, or -1 if none was delivered.
 * @param gid  receives the sender's group id, or -1 if none was delivered.
 * @param err  optional; if non-null it receives the errno value of
 *             recvmsg(2) on a read failure, or 0 if a message was read but
 *             carried no usable credentials.
 * @return true if a message with credentials was received, false otherwise.
 */
bool pvsPrivInputRecvMessage(int sock, void* buf, size_t& len,
		pid_t& pid, uid_t& uid, gid_t& gid, int* err)
{
	// The union forces an alignment suitable for struct cmsghdr, which a
	// plain char array does not guarantee.
	union
	{
		char bytes[1024];
		struct cmsghdr align;
	} ctl;

	struct iovec iov;
	iov.iov_base = buf;
	iov.iov_len = len;

	struct msghdr msg;
	memset(&msg, 0, sizeof(msg));
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_control = ctl.bytes;
	msg.msg_controllen = sizeof(ctl.bytes);

	const ssize_t bytes_read = recvmsg(sock, &msg, 0);
	if(bytes_read < 0)
	{
		// Save errno first: the logging call may overwrite it.
		const int savedErrno = errno;
		qWarning("Could not read from socket: %s", strerror(savedErrno));
		if(err)
			*err = savedErrno;
		return false;
	}

	pid = -1;
	uid = -1;
	gid = -1;

	bool haveCredentials = false;
	// Walk ALL control messages: stopping at the credentials would leave
	// descriptors from a later SCM_RIGHTS message open.
	for(struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg))
	{
		if(cmsg->cmsg_level != SOL_SOCKET)
			continue;

		if(cmsg->cmsg_type == SCM_CREDENTIALS)
		{
			// Only the first credentials message is used, and only if it
			// is large enough to hold a struct ucred.
			if(!haveCredentials
					&& static_cast<size_t>(cmsg->cmsg_len) >= CMSG_LEN(sizeof(struct ucred)))
			{
				struct ucred creds;
				memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds));
				pid = creds.pid;
				uid = creds.uid;
				gid = creds.gid;
				haveCredentials = true;
			}
		}
		else if(cmsg->cmsg_type == SCM_RIGHTS)
		{
			// We need to close passed file descriptors. If we don't, we
			// have a denial-of-service vulnerability.
			// cmsg_len includes the header, so subtract it before counting.
			const size_t total = static_cast<size_t>(cmsg->cmsg_len);
			const size_t header = CMSG_LEN(0);
			const size_t num_fds = (total > header) ? (total - header) / sizeof(int) : 0;
			const unsigned char* data = CMSG_DATA(cmsg);
			for(size_t i = 0; i < num_fds; i++)
			{
				int fd;
				memcpy(&fd, data + i * sizeof(int), sizeof(fd));
				close(fd);
			}
		}
	}

	if(pid == static_cast<pid_t>(-1) || uid == static_cast<uid_t>(-1) || gid == static_cast<gid_t>(-1))
	{
		// Not an OS-level error: the message simply had no credentials.
		if(err)
			*err = 0;
		return false;
	}

	return true;
}