/*
# Copyright (c) 2009 - OpenSLX Project, Computer Center University of Freiburg
#
# This program is free software distributed under the GPL version 2.
# See http://openslx.org/COPYING
#
# If you have any feedback please consult http://openslx.org/feedback and
# send your suggestions, praise, or complaints to feedback@openslx.org
#
# General information about OpenSLX can be found at http://openslx.org/
# --------------------------------------------------------------------------
# pvsPrivInputSocket.cpp:
# - Centralize knowledge of socket address and connection options
# for pvsprivinputd - implementation
# --------------------------------------------------------------------------
*/
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <QtDebug>
#include <QSettings>
#include "pvsPrivInputSocket.h"
using namespace std;
// Fallback for system headers that do not define UNIX_PATH_MAX: 108 is the
// size of sockaddr_un::sun_path as documented in unix(7).
#ifndef UNIX_PATH_MAX
# define UNIX_PATH_MAX 108 /* according to unix(7) */
#endif
// Process-wide settings object for pvsprivinputd. It is null until first use,
// created lazily by pvsPrivInputGetSettings() and deleted/recreated by
// pvsPrivInputReopenSettings().
// NOTE(review): access is not synchronized; presumably only used from a
// single thread - confirm.
static QSettings* pvsPrivInputSettings = 0;
/*
 * Returns the fixed location of the pvsprivinputd configuration file, which
 * pvsPrivInputGetSettings() opens in INI format.
 */
QString pvsPrivInputGetSettingsPath()
{
	const char* const configFile = "/etc/pvsprivinputd.conf";
	return QString(configFile);
}
/*
 * Returns the process-wide settings object. On the first call, this opens
 * pvsPrivInputGetSettingsPath() as an INI file. The returned object stays
 * owned by this module; callers must not delete it.
 */
QSettings* pvsPrivInputGetSettings()
{
	// Already opened: hand out the existing instance.
	if(pvsPrivInputSettings)
		return pvsPrivInputSettings;

	// First use: open the configuration file.
	QString const configPath = pvsPrivInputGetSettingsPath();
	pvsPrivInputSettings = new QSettings(configPath, QSettings::IniFormat);
	return pvsPrivInputSettings;
}
/*
 * Discards the current settings object (if any) and opens the configuration
 * file again, e.g. to pick up changes made on disk. Any pointer previously
 * returned by pvsPrivInputGetSettings() becomes invalid.
 */
QSettings* pvsPrivInputReopenSettings()
{
	// Detach the old instance first; deleting a null pointer is a no-op,
	// so no explicit check is needed.
	QSettings* const stale = pvsPrivInputSettings;
	pvsPrivInputSettings = 0;
	delete stale;
	return pvsPrivInputGetSettings();
}
/*
 * Returns the filesystem path of the pvsprivinputd UNIX socket. It is taken
 * from the "socketpath" setting and falls back to /tmp/pvsprivinputd.sock.
 */
QString pvsPrivInputGetSocketAddress()
{
	QSettings* const settings = pvsPrivInputGetSettings();
	QVariant const configured = settings->value("socketpath", "/tmp/pvsprivinputd.sock");
	return configured.toString();
}
/*
 * Sets SO_PASSCRED on sock so that the kernel attaches the sender's
 * credentials (SCM_CREDENTIALS) to received messages.
 * Returns true on success. On failure, errno is left as set by setsockopt().
 */
bool pvsPrivInputEnableReceiveCredentials(int sock)
{
	int const enable = 1;
	int const rc = setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &enable, sizeof enable);
	return rc >= 0;
}
/*
 * Creates a UNIX datagram socket connected to the pvsprivinputd socket, whose
 * path comes from pvsPrivInputGetSocketAddress().
 *
 * Returns the connected descriptor, which the caller owns and must close(),
 * or -1 on failure (a warning is logged).
 */
int pvsPrivInputMakeClientSocket()
{
	int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
	if(sock < 0)
	{
		qWarning("Could not create a socket: %s", strerror(errno));
		return -1;
	}
	QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit();
	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	// The path plus its terminating NUL must fit into sun_path. Silently
	// truncating a longer path would address a different socket.
	if(static_cast<size_t>(socketPath.size()) >= sizeof(addr.sun_path))
	{
		qWarning("Socket path %s is too long", socketPath.constData());
		close(sock);
		return -1;
	}
	strncpy(addr.sun_path, socketPath.constData(), sizeof(addr.sun_path) - 1);
	if(connect(sock, reinterpret_cast<struct sockaddr*>(&addr),
			sizeof(addr)) < 0)
	{
		qWarning("Could not connect to pvsprivinputd at %s: %s", socketPath.constData(), strerror(errno));
		close(sock);
		return -1;
	}
	return sock;
}
/*
 * Creates the pvsprivinputd server socket. This is a UNIX datagram socket
 * bound to pvsPrivInputGetSocketAddress(), with SO_PASSCRED enabled so that
 * every received message carries the sender's credentials.
 *
 * Returns the socket descriptor, which the caller owns, or -1 on failure
 * (a critical message is logged).
 */
int pvsPrivInputMakeServerSocket()
{
	int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
	if(sock < 0)
	{
		qCritical("Could not create a socket: %s", strerror(errno));
		return -1;
	}
	// Bind to the address:
	QByteArray socketPath = pvsPrivInputGetSocketAddress().toLocal8Bit();
	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	// The path plus its terminating NUL must fit into sun_path. Silently
	// truncating a longer path would bind a different socket.
	if(static_cast<size_t>(socketPath.size()) >= sizeof(addr.sun_path))
	{
		qCritical("Socket path %s is too long", socketPath.constData());
		close(sock);
		return -1;
	}
	strncpy(addr.sun_path, socketPath.constData(), sizeof(addr.sun_path) - 1);
	if(bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0)
	{
		qCritical("Could not bind socket to %s: %s", socketPath.constData(), strerror(errno));
		close(sock);
		return -1;
	}
	// Announce that credentials are requested:
	if(!pvsPrivInputEnableReceiveCredentials(sock))
	{
		// We will not operate without credentials.
		qCritical("Could not request peer credentials: %s", strerror(errno));
		close(sock);
		return -1;
	}
	// No listen(): this is a connectionless SOCK_DGRAM socket. listen()
	// would only be needed if this were switched to SOCK_STREAM.
	return sock;
}
/*
 * Sends one message over a connected UNIX datagram socket.
 *
 * sock: connected socket, e.g. from pvsPrivInputMakeClientSocket().
 * buf, _len: the message and its length in bytes.
 * err: optional. On failure it receives the errno of the failed send(), or 0
 *      if send() succeeded but transmitted fewer than _len bytes.
 * Returns true only if the whole message was sent.
 */
bool pvsPrivInputSendMessage(int sock, void* buf, size_t _len, int* err)
{
	/*
	 * Portability note: All UNIX-like systems can transmit credentials over UNIX
	 * sockets, but only Linux does it automagically.
	 */
	// send(2) does not split messages on a SOCK_DGRAM socket.
	ssize_t sent = send(sock, buf, _len, 0);
	if(sent < 0)
	{
		// Capture errno before logging, which may overwrite it.
		int const send_errno = errno;
		qWarning("Failed to send message of length %lu over socket %d: %s",
				static_cast<unsigned long>(_len), sock, strerror(send_errno));
		if(err)
			*err = send_errno;
		return false;
	}
	if(static_cast<size_t>(sent) < _len)
	{
		qWarning("Failed to send a complete message of length %lu over socket %d, only %ld bytes were sent",
				static_cast<unsigned long>(_len), sock, static_cast<long>(sent));
		// send() itself succeeded, so errno carries no information here.
		if(err)
			*err = 0;
		return false;
	}
	return true;
}
/*
 * Receives one datagram from a socket with SO_PASSCRED enabled (see
 * pvsPrivInputMakeServerSocket()) and extracts the sender's credentials.
 *
 * buf, len: receive buffer and its capacity in bytes.
 * pid, uid, gid: set to the sender's credentials, or to -1 when the message
 *      carried none.
 * err: optional. On failure it receives the errno of the failed recvmsg(),
 *      or 0 when the message carried no credentials.
 * Returns true only if a message was read and it carried credentials.
 *
 * Any file descriptors passed along with the message (SCM_RIGHTS) are closed,
 * so peers cannot exhaust this process's descriptor table.
 */
bool pvsPrivInputRecvMessage(int sock, void* buf, size_t& len,
		pid_t& pid, uid_t& uid, gid_t& gid, int* err)
{
	struct iovec iov;
	iov.iov_base = buf;
	iov.iov_len = len;

	// The control buffer must be aligned for struct cmsghdr, which a plain
	// char array does not guarantee.
	union
	{
		char data[1024];
		struct cmsghdr align;
	} ctl;

	struct msghdr msg;
	memset(&msg, 0, sizeof(msg));
	msg.msg_name = 0;
	msg.msg_namelen = 0;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_control = ctl.data;
	msg.msg_controllen = sizeof(ctl.data);
	msg.msg_flags = 0;

	ssize_t bytes_read = recvmsg(sock, &msg, 0);
	if(bytes_read < 0)
	{
		// Capture errno before logging, which may overwrite it.
		int const recv_errno = errno;
		qWarning("Could not read from socket: %s", strerror(recv_errno));
		if(err)
			*err = recv_errno;
		return false;
	}
	// NOTE(review): len is not updated with bytes_read, so callers cannot
	// detect short messages. Confirm whether callers expect len to be an
	// in/out parameter.

	pid = -1;
	uid = -1;
	gid = -1;
	for(struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg))
	{
		if(cmsg->cmsg_level != SOL_SOCKET)
			continue;

		if(cmsg->cmsg_type == SCM_CREDENTIALS
				&& cmsg->cmsg_len >= CMSG_LEN(sizeof(struct ucred)))
		{
			// Copy out instead of casting, so CMSG_DATA alignment does not matter.
			struct ucred creds;
			memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds));
			pid = creds.pid;
			uid = creds.uid;
			gid = creds.gid;
			// Do not stop here: passed descriptors (SCM_RIGHTS) may follow
			// the credentials and must still be closed.
		}
		else if(cmsg->cmsg_type == SCM_RIGHTS)
		{
			// We need to close passed file descriptors. If we don't, we
			// have a denial-of-service vulnerability. cmsg_len includes
			// the header, so only the bytes after CMSG_LEN(0) hold fds.
			size_t const payload = cmsg->cmsg_len > CMSG_LEN(0)
				? cmsg->cmsg_len - CMSG_LEN(0) : 0;
			unsigned char const* fd_data = CMSG_DATA(cmsg);
			for(size_t i = 0; i < payload / sizeof(int); i++)
			{
				int fd;
				memcpy(&fd, fd_data + i * sizeof(int), sizeof(fd));
				close(fd);
			}
		}
	}
	if(pid == (pid_t)-1 || uid == (uid_t)-1 || gid == (gid_t)-1)
	{
		// A message without credentials is not a syscall error, so report 0.
		if(err)
			*err = 0;
		return false;
	}
	return true;
}