package org.openslx.taskmanager.tasks;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.openslx.satserver.util.Util;
import org.openslx.taskmanager.api.AbstractTask;
import com.google.gson.annotations.Expose;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
public class RemoteExec extends AbstractTask
{
// NOTE(review): the @Expose fields below are presumably filled in by Gson from
// the task request - confirm against the task manager's deserialization setup.
/** Hosts to run the command on; must contain at least one entry. */
@Expose
private Client[] clients;
/** Private key material handed to JSch as the login identity. */
@Expose
private String sshkey;
/** SSH port used for every client that does not carry a valid port of its own (1-65535). */
@Expose
private int port;
/** Shell command appended to the generated prologue and run on each client. */
@Expose
private String command;
/** Time budget per client in seconds, covering connect and command execution; defaults to 5. */
@Expose
private int timeoutSeconds = 5;
/** SSH client shared by all workers; created in initTask(). */
private JSch sshClient;
/** Status object of this task, holding the per-client results and collected errors. */
private Output status = new Output();
/**
 * Validates the task parameters, registers one {@link Result} per client and
 * loads the SSH key into a fresh JSch instance.
 * <p>
 * Every problem found is appended to the status object, so a single run
 * reports all invalid parameters at once instead of stopping at the first.
 *
 * @return true if all parameters are valid and the key could be loaded,
 *         false otherwise (details are in the status error text)
 */
@Override
protected boolean initTask()
{
	this.setStatusObject( this.status );
	if ( Util.isEmpty( sshkey ) ) {
		status.addError( "No SSH key given" );
	}
	if ( port < 1 || port > 65535 ) {
		status.addError( "Invalid port number" );
	}
	if ( Util.isEmpty( command ) ) {
		status.addError( "No command given" );
	}
	if ( clients == null || clients.length == 0 ) {
		status.addError( "No clients given" );
	}
	if ( timeoutSeconds < 1 ) {
		status.addError( "Invalid timeout given" );
	}
	// A missing client list has been reported above; it must not be iterated.
	if ( clients != null ) {
		// Compute in long and clamp so that huge timeouts cannot wrap around to a negative budget.
		final int budgetMs = (int)Math.min( timeoutSeconds * 1000L, Integer.MAX_VALUE );
		for ( Client client : clients ) {
			// A null array entry is treated like an entry without address.
			if ( client == null || Util.isEmpty( client.clientip ) ) {
				status.addError( "Client without host/address given!" );
				continue;
			}
			// Results are keyed by machine UUID if known, by address otherwise.
			if ( !Util.isEmpty( client.machineuuid ) ) {
				status.result.put( client.machineuuid, new Result() );
			} else {
				status.result.put( client.clientip, new Result() );
			}
			client.timeoutLeft = budgetMs;
		}
	}
	if ( status.error != null )
		return false;
	// NOTE(review): this is a static JSch setting, so host key verification is
	// switched off for every JSch session in this JVM - confirm this is acceptable.
	JSch.setConfig( "StrictHostKeyChecking", "no" );
	sshClient = new JSch();
	try {
		// Explicit charset: the key bytes must not depend on the platform default encoding.
		sshClient.addIdentity( "", sshkey.getBytes( StandardCharsets.UTF_8 ), null, null );
	} catch ( JSchException e ) {
		status.addError( e.toString() );
		return false;
	}
	return true;
}
/**
 * Runs the command on every client that has an address, using a pool of at
 * most four worker threads, and waits until the workers are done.
 * <p>
 * The wait limit is derived from the configured per-client timeout, so a
 * timeout larger than the default does not make this method return while
 * workers are still updating the status. Workers that are still running when
 * the wait ends, or when this thread is interrupted, are cancelled.
 *
 * @return false if the calling thread was interrupted while waiting, true otherwise
 */
@Override
protected boolean execute()
{
	// newFixedThreadPool rejects a size of zero, so never go below one thread.
	final int threads = Math.max( 1, Math.min( clients.length, 4 ) );
	ExecutorService tp = Executors.newFixedThreadPool( threads );
	for ( final Client client : clients ) {
		if ( Util.isEmpty( client.clientip ) )
			continue;
		tp.submit( new Runnable() {
			@Override
			public void run()
			{
				runOnClient( client );
			}
		} );
	}
	tp.shutdown();
	// Upper bound as if all clients ran one after another, each using its full
	// budget plus slack for the minimum connect timeouts enforced further down.
	final long waitSeconds = clients.length * ( timeoutSeconds + 5L );
	try {
		if ( !tp.awaitTermination( waitSeconds, TimeUnit.SECONDS ) ) {
			// Do not leave workers running once this task reports completion.
			tp.shutdownNow();
		}
	} catch ( InterruptedException e ) {
		tp.shutdownNow();
		Thread.currentThread().interrupt();
		return false;
	}
	return true;
}

/**
 * Worker body: connects to a single client, runs the command through
 * execCommand and records any failure in that client's result.
 */
private void runOnClient( Client client )
{
	String clientId = !Util.isEmpty( client.machineuuid ) ? client.machineuuid : client.clientip;
	Result clientResult = status.result.get( clientId );
	if ( clientResult == null ) {
		status.addError( "WHAT IN THE F*CK!? No clientResult for " + clientId );
		return;
	}
	Session session = null;
	ChannelExec channel = null;
	clientResult.state = State.CONNECTING;
	try {
		long st = System.currentTimeMillis();
		session = sshClient.getSession(
				Util.isEmpty( client.username ) ? "root" : client.username,
				client.clientip,
				( client.port > 0 && client.port < 65536 ) ? client.port : port
		);
		// Spend at most half of the budget on connecting, but never less than 1.1 seconds.
		session.connect( Math.max( client.timeoutLeft / 2, 1100 ) );
		clientResult.state = State.PRE_EXEC;
		channel = (ChannelExec)session.openChannel( "exec" );
		client.timeoutLeft -= System.currentTimeMillis() - st;
		execCommand( channel, client, clientResult );
	} catch ( Exception e ) {
		// Any failure is reported through the client's stderr text and state.
		clientResult.stderr.append( e.toString() + "\n" );
		clientResult.state = State.ERROR;
	} finally {
		// The channel stays null if connecting failed, so check instead of relying on a caught NPE.
		if ( channel != null ) {
			try {
				channel.disconnect();
			} catch ( Exception ignored ) {
				// best-effort cleanup
			}
		}
		if ( session != null ) {
			try {
				session.disconnect();
			} catch ( Exception ignored ) {
				// best-effort cleanup
			}
		}
	}
}
/**
 * Moves all characters that can currently be read from {@code in} without
 * blocking into {@code out}, limiting the combined size of both buffers.
 * <p>
 * As soon as {@code out} and {@code other} together hold more than 40000
 * characters, {@code out} is cut back to the space left next to
 * {@code other} (if there is any) and reading stops. An I/O error ends the
 * transfer; its stack trace is printed.
 *
 * @param cbuf scratch buffer used for the individual reads
 * @param in stream to drain
 * @param out buffer receiving the characters read
 * @param other the second output buffer, only consulted for its length
 */
private void copy( char[] cbuf, InputStreamReader in, StringBuffer out, StringBuffer other )
{
	try {
		for ( ;; ) {
			if ( !in.ready() )
				return;
			final int count = in.read( cbuf );
			if ( count == -1 )
				return;
			out.append( cbuf, 0, count );
			// Space that out may occupy next to what other already holds
			final int room = 40000 - other.length();
			if ( out.length() <= room )
				continue;
			if ( room > 0 && room < out.length() ) {
				out.setLength( room );
			}
			return;
		}
	} catch ( IOException e ) {
		e.printStackTrace();
	}
}
/**
 * Runs the task's command on an opened, not yet connected exec channel and
 * collects stdout and stderr into the given result.
 * <p>
 * Output is polled about every 250 ms until the channel is closed, the
 * client's remaining time budget is used up, the output limit is hit or the
 * thread is interrupted. The result ends up as DONE with the command's exit
 * status if the channel was closed by then, as TIMEOUT otherwise. The channel
 * is disconnected before returning.
 *
 * @param channel exec channel of a connected session
 * @param client target host; its remaining time budget is decreased while running
 * @param result receives state, output and exit code
 * @throws JSchException if the channel cannot be connected
 * @throws IOException if the channel's streams cannot be obtained
 */
private void execCommand( ChannelExec channel, Client client, Result result ) throws JSchException, IOException
{
	long st = System.currentTimeMillis();
	channel.setCommand( buildRemoteScript( client.machineuuid ) );
	// The streams have to be fetched before the channel gets connected.
	InputStreamReader stdout = new InputStreamReader( channel.getInputStream(), StandardCharsets.UTF_8 );
	InputStreamReader stderr = new InputStreamReader( channel.getErrStream(), StandardCharsets.UTF_8 );
	result.state = State.EXEC;
	channel.connect( Math.max( client.timeoutLeft - 1000, 500 ) );
	long now = System.currentTimeMillis();
	client.timeoutLeft -= now - st;
	// Read as long as we got time
	char[] cbuf = new char[2000];
	while ( client.timeoutLeft > 0 && !channel.isClosed() ) {
		st = now;
		try {
			// timeoutLeft is positive here, so the sleep time is always within 1..250 ms
			Thread.sleep( Math.min( 250, client.timeoutLeft ) );
		} catch ( InterruptedException e ) {
			Thread.currentThread().interrupt();
			break;
		} finally {
			// Charge the elapsed time (sleep plus the previous round of reading) to the budget
			now = System.currentTimeMillis();
			client.timeoutLeft -= now - st;
		}
		copy( cbuf, stdout, result.stdout, result.stderr );
		copy( cbuf, stderr, result.stderr, result.stdout );
		// Check for reasonable output size
		// NOTE(review): copy() already trims the combined output to 40000 characters in
		// most cases, so this condition looks like it rarely triggers - confirm the intent.
		if ( result.stdout.length() + result.stderr.length() > 40000 ) {
			status.addError( "Truncating output of client " + client.clientip );
			break;
		}
	}
	// Pick up whatever arrived since the last poll
	copy( cbuf, stdout, result.stdout, result.stderr );
	copy( cbuf, stderr, result.stderr, result.stdout );
	// Decide once, before disconnecting, whether the command finished, so that state
	// and exit code always agree and do not depend on what disconnect() does.
	final boolean finished = channel.isClosed();
	if ( finished ) {
		result.exitCode = channel.getExitStatus();
	}
	result.state = finished ? State.DONE : State.TIMEOUT;
	try {
		channel.disconnect();
	} catch ( Exception ignored ) {
		// best-effort cleanup
	}
}

/**
 * Builds the shell script sent to the client: a fixed PATH, an optional check of
 * /etc/system-uuid against the expected machine UUID, then the task's command.
 */
private String buildRemoteScript( String machineUuid )
{
	StringBuilder script = new StringBuilder(
			"export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/openslx/sbin:/opt/openslx/bin\n" );
	if ( !Util.isEmpty( machineUuid ) ) {
		// Quoted as a literal so that shell metacharacters in the UUID cannot alter the script
		script.append( "if [ \"$(cat /etc/system-uuid)\" != " ).append( shellQuote( machineUuid ) ).append( " ]; then\n" )
				.append( "echo 'Machine UUID mismatch' >&2\n" )
				.append( "exit 42\n" )
				.append( "fi\n" );
	}
	script.append( command );
	return script.toString();
}

/**
 * Wraps a string in single quotes for a POSIX shell, escaping embedded single quotes.
 */
private static String shellQuote( String s )
{
	return "'" + s.replace( "'", "'\\''" ) + "'";
}
/**
 * Status data of this task: the result of every client plus the collected
 * error messages.
 */
static class Output
{
	/** Per-client results, keyed by machine UUID or, if there is none, by the client's address */
	private Map<String, Result> result = new ConcurrentHashMap<>();
	/** All error messages so far, one per line; null as long as no error was added */
	private String error;

	/**
	 * Appends a message, followed by a newline, to the error text.
	 *
	 * @param e message to add
	 */
	private synchronized void addError( String e )
	{
		final String line = e + "\n";
		error = ( error == null ) ? line : error + line;
	}
}
/**
 * One target host of the task, together with its remaining time budget.
 */
static class Client
{
/**
 * Optional machine UUID. If set, it is used as the key of this client's
 * result and compared to /etc/system-uuid on the host before the command runs.
 */
@Expose
private String machineuuid;
/** Host name or address to connect to; clients without one are rejected. */
@Expose
private String clientip;
/** SSH port of this client; the task-wide port is used unless this is within 1-65535. */
@Expose
private int port;
/** Login name; "root" is used if this is empty. */
@Expose
private String username;
/** How many ms of the given timeout are left */
private int timeoutLeft;
}
/**
 * Progress of the command on a single client.
 */
static enum State
{
/** Initial state of a result; the client has not been picked up by a worker yet. */
QUEUED,
/** A worker has started on the client; the SSH session is being set up. */
CONNECTING,
/** The SSH session is connected; the exec channel is being opened. */
PRE_EXEC,
/** The command has been handed to the channel and output is being collected. */
EXEC,
/** The channel was closed when output collection ended. */
DONE,
/** An exception occurred; its text was appended to the client's stderr. */
ERROR,
/** Output collection ended while the channel was still open. */
TIMEOUT
}
/**
 * Outcome of the command on a single client.
 */
static class Result
{
/** Current progress; starts out as QUEUED. */
State state = State.QUEUED;
/** Standard output collected from the command. */
StringBuffer stdout = new StringBuffer();
/** Standard error collected from the command, plus the text of any exception caught by the worker. */
StringBuffer stderr = new StringBuffer();
/** Exit status as reported by the channel; remains -1 if none was recorded. */
int exitCode = -1;
}
}