package org.openslx.taskmanager.tasks;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openslx.satserver.util.Util;
import org.openslx.taskmanager.api.AbstractTask;
import com.google.gson.annotations.Expose;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
/**
 * Task that runs one shell command on one or more remote clients via SSH and collects
 * each client's stdout, stderr and exit code.
 * <p>
 * All clients share one private key ({@link #sshkey}), one default port and one command.
 * Clients are processed in parallel by a small fixed-size thread pool. Each client gets its
 * own time budget of {@link #timeoutSeconds}, which covers connecting, running the command
 * and collecting its output.
 */
public class RemoteExec extends AbstractTask
{

	private static final Logger LOGGER = LogManager.getLogger( RemoteExec.class );

	/**
	 * Approximate maximum number of characters captured per output stream (stdout or stderr)
	 * of a single client; further output is discarded.
	 */
	protected final static int MAX_OUTPUT_PER_CLIENT = 400000;

	/** Upper bound on the number of clients contacted concurrently. */
	private static final int MAX_PARALLEL_CONNECTIONS = 4;

	/** Extra seconds granted to the worker pool on top of the summed per-client timeouts. */
	private static final int POOL_GRACE_SECONDS = 5;

	/** Clients to run the command on. */
	@Expose
	private Client[] clients;

	/** Private key material used to authenticate against every client. */
	@Expose
	private String sshkey;

	/** Default SSH port; used for every client that does not specify a valid port itself. */
	@Expose
	private int port;

	/** Shell command to execute on each client. */
	@Expose
	private String command;

	/** Time budget per client, in seconds. */
	@Expose
	private int timeoutSeconds = 5;

	private JSch sshClient;

	private Output status = new Output();

	/** Creates an empty task; fields are expected to be populated by deserialization. */
	public RemoteExec()
	{
	}

	/**
	 * Creates a fully configured task and immediately runs {@link #initTask()} on it.
	 *
	 * @param clients clients to run the command on
	 * @param sshkey private key used for authentication
	 * @param port default SSH port
	 * @param command shell command to execute
	 * @param timeoutSeconds time budget per client in seconds
	 */
	public RemoteExec(Client[] clients, String sshkey, int port, String command, int timeoutSeconds)
	{
		this.clients = clients;
		this.sshkey = sshkey;
		this.port = port;
		this.command = command;
		this.timeoutSeconds = timeoutSeconds;
		initTask();
	}

	/**
	 * Validates the task parameters, prepares one {@link Result} per valid client and loads
	 * the SSH key into a fresh {@link JSch} instance.
	 * <p>
	 * All validation problems are collected in the status object rather than stopping at the
	 * first one.
	 *
	 * @return {@code true} if the task is ready to run, {@code false} if any error was recorded
	 */
	@Override
	protected boolean initTask()
	{
		this.setStatusObject( this.status );
		if ( Util.isEmpty( sshkey ) ) {
			status.addError( "No SSH key given" );
		}
		if ( port < 1 || port > 65535 ) {
			status.addError( "Invalid port number" );
		}
		if ( Util.isEmpty( command ) ) {
			status.addError( "No command given" );
		}
		if ( clients == null || clients.length == 0 ) {
			status.addError( "No clients given" );
		}
		if ( timeoutSeconds < 1 ) {
			status.addError( "Invalid timeout given" );
		}
		if ( clients != null ) {
			// Compute in long and clamp so huge timeouts cannot overflow into negative budgets
			int timeoutMs = (int)Math.min( Integer.MAX_VALUE, timeoutSeconds * 1000L );
			for ( Client client : clients ) {
				if ( client == null || Util.isEmpty( client.clientip ) ) {
					status.addError( "Client without host/address given!" );
					continue;
				}
				status.result.put( client.getId(), new Result() );
				client.timeoutLeft = timeoutMs;
			}
		}
		if ( status.error != null )
			return false;
		JSch.setConfig( "StrictHostKeyChecking", "no" );
		sshClient = new JSch();
		try {
			sshClient.addIdentity( "", sshkey.getBytes( StandardCharsets.UTF_8 ), null, null );
		} catch ( JSchException e ) {
			status.addError( e.toString() );
			return false;
		}
		return true;
	}

	/**
	 * Runs the command on all clients using at most {@link #MAX_PARALLEL_CONNECTIONS} worker
	 * threads and waits for them to finish.
	 * <p>
	 * If the workers do not finish within the overall deadline, they are interrupted and the
	 * partial results are kept.
	 *
	 * @return {@code false} if there are no clients or the calling thread was interrupted,
	 *         {@code true} otherwise
	 */
	@Override
	protected boolean execute()
	{
		if ( clients == null || clients.length == 0 )
			return false; // initTask() rejects this too; avoids an NPE / zero-sized pool
		ExecutorService tp = Executors.newFixedThreadPool( Math.min( clients.length, MAX_PARALLEL_CONNECTIONS ) );
		for ( final Client client : clients ) {
			if ( client == null || Util.isEmpty( client.clientip ) )
				continue;
			tp.submit( new Runnable() {
				@Override
				public void run()
				{
					runClient( client );
				}
			} );
		}
		tp.shutdown();
		// Generous deadline: as if every client used its full budget one after another
		long maxWaitSeconds = (long)clients.length * timeoutSeconds + POOL_GRACE_SECONDS;
		try {
			if ( !tp.awaitTermination( maxWaitSeconds, TimeUnit.SECONDS ) ) {
				LOGGER.warn( "Not all clients finished within {}s, interrupting remaining workers", maxWaitSeconds );
				tp.shutdownNow();
			}
		} catch ( InterruptedException e ) {
			// Stop workers so they don't keep mutating the status after we returned
			tp.shutdownNow();
			Thread.currentThread().interrupt();
			return false;
		}
		return true;
	}

	/**
	 * Returns the status object holding per-client results and accumulated errors.
	 *
	 * @return this task's status object
	 */
	public Output getStatusObject()
	{
		return this.status;
	}

	/**
	 * Connects to a single client, runs the command and records the outcome in that client's
	 * {@link Result}.
	 * <p>
	 * Logs in as {@code root} unless the client specifies a username. Uses the client's own
	 * port if valid, else {@link #port}. Any exception is appended to the result's stderr and
	 * the state is set to {@link State#ERROR}.
	 *
	 * @param client client to process; must have a non-empty address
	 */
	private void runClient( Client client )
	{
		String clientId = client.getId();
		Result clientResult = status.result.get( clientId );
		if ( clientResult == null ) {
			status.addError( "WHAT IN THE F*CK!? No clientResult for " + clientId );
			return;
		}
		Session session = null;
		ChannelExec channel = null;
		clientResult.state = State.CONNECTING;
		try {
			long st = System.currentTimeMillis();
			session = sshClient.getSession(
					Util.isEmpty( client.username ) ? "root" : client.username,
					client.clientip,
					( client.port > 0 && client.port < 65536 ) ? client.port : port );
			// Give the connect at most half the budget, but never less than 1.1s
			session.connect( Math.max( client.timeoutLeft / 2, 1100 ) );
			clientResult.state = State.PRE_EXEC;
			channel = (ChannelExec)session.openChannel( "exec" );
			client.timeoutLeft -= System.currentTimeMillis() - st;
			execCommand( channel, client, clientResult );
		} catch ( Exception e ) {
			clientResult.stderr.append( e.toString() + "\n" );
			clientResult.state = State.ERROR;
		} finally {
			// channel may still be null if connecting or opening the channel failed
			if ( channel != null ) {
				try {
					channel.disconnect();
				} catch ( Exception ignored ) {
					// best-effort cleanup
				}
			}
			if ( session != null ) {
				try {
					session.disconnect();
				} catch ( Exception ignored ) {
					// best-effort cleanup
				}
			}
		}
	}

	/**
	 * Wraps a string in single quotes for a POSIX shell, escaping embedded single quotes, so
	 * the shell never expands or splits it.
	 */
	private static String shellQuote( String s )
	{
		return "'" + s.replace( "'", "'\\''" ) + "'";
	}

	/**
	 * Runs the configured command on an opened exec channel. It then polls until the remote
	 * side closes the channel or the client's remaining time budget runs out.
	 * <p>
	 * The command runs with a fixed {@code PATH}. If the client has a machine UUID, the script
	 * first compares it to the contents of {@code /etc/system-uuid`. On a mismatch it exits
	 * with code 42 without running the command.
	 *
	 * @param channel unconnected exec channel of an established session
	 * @param client client whose {@code timeoutLeft} budget is consumed
	 * @param result receives output, final state and (if the command finished) its exit code
	 */
	private void execCommand( ChannelExec channel, Client client, Result result ) throws JSchException, IOException
	{
		long st = System.currentTimeMillis();
		String cmd = "export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/openslx/sbin:/opt/openslx/bin\n";
		if ( !Util.isEmpty( client.machineuuid ) ) {
			// Quote the UUID so unexpected characters in it cannot alter the script
			cmd += "if [ \"$(cat /etc/system-uuid)\" != " + shellQuote( client.machineuuid ) + " ]; then\n"
					+ "echo 'Machine UUID mismatch' >&2\n"
					+ "exit 42\n"
					+ "fi\n";
		}
		cmd += command;
		channel.setCommand( cmd );
		channel.setOutputStream( new LimitedBufferStream( result.stdout ) );
		channel.setErrStream( new LimitedBufferStream( result.stderr ) );
		result.state = State.EXEC;
		channel.connect( Math.max( client.timeoutLeft - 1000, 500 ) );
		long now = System.currentTimeMillis();
		client.timeoutLeft -= now - st;
		// Poll as long as the command is still running and we have time left
		while ( client.timeoutLeft > 0 && !channel.isClosed() ) {
			st = now;
			try {
				Thread.sleep( Math.min( 250, client.timeoutLeft ) );
			} catch ( InterruptedException e ) {
				Thread.currentThread().interrupt();
				break;
			} catch ( Exception ee ) {
				LOGGER.warn( "Cannot sleep", ee );
				break;
			} finally {
				now = System.currentTimeMillis();
				client.timeoutLeft -= now - st;
			}
		}
		// Must be evaluated before disconnect(): JSch's disconnect() marks the channel as
		// closed, which would make a timed-out command look finished.
		if ( channel.isClosed() ) {
			result.exitCode = channel.getExitStatus();
			result.state = State.DONE;
		} else {
			result.state = State.TIMEOUT;
		}
		try {
			channel.disconnect();
		} catch ( Exception ignored ) {
			// best-effort cleanup
		}
	}

	/**
	 * OutputStream that decodes incoming bytes as UTF-8 and appends the text to a
	 * {@link StringBuffer}.
	 * <p>
	 * Once the buffer holds {@link RemoteExec#MAX_OUTPUT_PER_CLIENT} characters, further
	 * output is discarded. Malformed or unmappable input is replaced instead of rejected. An
	 * incomplete multi-byte sequence at the end of a chunk is kept until more bytes arrive.
	 */
	static class LimitedBufferStream extends OutputStream
	{
		private final StringBuffer sb;
		private final CharBuffer cb = CharBuffer.allocate( 200 );
		private final ByteBuffer bb = ByteBuffer.allocate( 200 );
		private final CharsetDecoder decoder;

		/**
		 * @param sb buffer that receives the decoded text
		 */
		public LimitedBufferStream(StringBuffer sb)
		{
			this.sb = sb;
			decoder = StandardCharsets.UTF_8.newDecoder();
			decoder.onMalformedInput( CodingErrorAction.REPLACE );
			decoder.onUnmappableCharacter( CodingErrorAction.REPLACE );
			decoder.reset();
		}

		@Override
		public void write( int b ) throws IOException
		{
			if ( sb.length() >= RemoteExec.MAX_OUTPUT_PER_CLIENT )
				return;
			bb.put( (byte) ( b & 0xff ) );
			decode();
		}

		@Override
		public void write( byte[] b, int off, int len ) throws IOException
		{
			// Re-check the limit per chunk so one large write cannot overshoot it by much
			while ( len > 0 && sb.length() < RemoteExec.MAX_OUTPUT_PER_CLIENT ) {
				int nlen = Math.min( bb.remaining(), len );
				if ( nlen <= 0 )
					throw new RuntimeException( "Empty buffer" );
				bb.put( b, off, nlen );
				off += nlen;
				len -= nlen;
				decode();
			}
		}

		/**
		 * Decodes as much of the pending bytes as possible and appends the result to
		 * {@link #sb}. Leftover bytes of an incomplete sequence are moved to the front of
		 * {@link #bb}.
		 */
		private void decode()
		{
			if ( bb.position() == 0 )
				return;
			// Casts to Buffer keep the bytecode compatible with Java 8 runtimes
			( (Buffer)bb ).flip();
			try {
				decoder.decode( bb, cb, false );
			} catch ( Throwable t ) {
				LOGGER.warn( "Cannot convert data to UTF8", t );
			}
			bb.compact();
			( (Buffer)cb ).flip();
			sb.append( cb.toString() );
			( (Buffer)cb ).clear();
		}
	}

	/**
	 * Status object of this task: per-client results plus accumulated error messages.
	 */
	static class Output
	{
		/** Client id (machine UUID if given, otherwise client address) -> result. */
		Map<String, Result> result = new ConcurrentHashMap<>();
		/** Newline-terminated error messages, or {@code null} if there were none. */
		String error;

		/** Appends an error message followed by a newline; safe to call from worker threads. */
		private synchronized void addError( String e )
		{
			if ( error == null ) {
				error = e + "\n";
			} else {
				error += e + "\n";
			}
		}
	}

	/**
	 * One target machine.
	 */
	static class Client
	{
		/** Optional machine UUID; if set, it is verified on the remote host before running. */
		@Expose
		private String machineuuid;
		/** Host name or address to connect to. */
		@Expose
		private String clientip;
		/** Optional SSH port overriding the task's default (ignored unless 1..65535). */
		@Expose
		private int port;
		/** Optional login user; {@code root} is used if empty. */
		@Expose
		private String username;
		/** How many ms of the given timeout are left */
		private int timeoutLeft;

		public Client()
		{
		}

		public Client(String machineuuid, String clientip, int port, String username)
		{
			this.machineuuid = machineuuid;
			this.clientip = clientip;
			this.port = port;
			this.username = username;
		}

		/**
		 * Returns the key under which this client's {@link Result} is stored: the machine UUID
		 * if set, otherwise the client address.
		 */
		String getId()
		{
			return Util.isEmpty( machineuuid ) ? clientip : machineuuid;
		}
	}

	/** Processing state of a single client. */
	static enum State
	{
		/** Not started yet. */
		QUEUED,
		/** SSH session is being established. */
		CONNECTING,
		/** Session established, exec channel being opened. */
		PRE_EXEC,
		/** Command started; output is being collected. */
		EXEC,
		/** Remote side closed the channel within the time budget. */
		DONE,
		/** An exception occurred; its text was appended to stderr. */
		ERROR,
		/** Time budget ran out before the channel was closed. */
		TIMEOUT
	}

	/** Outcome of running the command on a single client. */
	static class Result
	{
		State state = State.QUEUED;
		StringBuffer stdout = new StringBuffer();
		StringBuffer stderr = new StringBuffer();
		/** Remote exit status; stays -1 unless the command finished within its budget. */
		int exitCode = -1;
	}

}