summaryrefslogblamecommitdiffstats
path: root/src/main/java/org/openslx/taskmanager/tasks/RemoteExec.java
blob: 3bd0954c85943dfac1ebb553b489eafe68a2700f (plain) (tree)



























































































































































































































































































                                                                                                                                             
package org.openslx.taskmanager.tasks;

import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.openslx.satserver.util.Util;
import org.openslx.taskmanager.api.AbstractTask;

import com.google.gson.annotations.Expose;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;

public class RemoteExec extends AbstractTask
{

	/** Maximum number of clients that are handled concurrently. */
	private static final int MAX_PARALLEL = 4;

	/** Upper bound for the combined stdout + stderr length kept per client. */
	private static final int MAX_OUTPUT_CHARS = 40000;

	/**
	 * Extra seconds granted per client on top of {@link #timeoutSeconds} when waiting for the
	 * worker pool, to cover connection setup that is not fully bounded by the budget.
	 */
	private static final int PER_CLIENT_SLACK_SECONDS = 5;

	/** Target machines; every entry needs at least {@code clientip}. */
	@Expose
	private Client[] clients;

	/** Private key (passed to {@link JSch#addIdentity}) used to log in to every client. */
	@Expose
	private String sshkey;

	/** Default SSH port, used for clients that do not specify a valid port of their own. */
	@Expose
	private int port;

	/** Shell command executed on every client. */
	@Expose
	private String command;

	/** Per-client time budget in seconds, shared by connecting and running the command. */
	@Expose
	private int timeoutSeconds = 5;

	private JSch sshClient;

	private Output status = new Output();

	/**
	 * Validates the task parameters, creates one {@link Result} per client and loads the SSH key.
	 *
	 * @return false if any parameter is invalid or the key cannot be loaded; the reasons are
	 *         collected in the status object's error text
	 */
	@Override
	protected boolean initTask()
	{
		this.setStatusObject( this.status );

		if ( Util.isEmpty( sshkey ) ) {
			status.addError( "No SSH key given" );
		}
		if ( port < 1 || port > 65535 ) {
			status.addError( "Invalid port number" );
		}
		if ( Util.isEmpty( command ) ) {
			status.addError( "No command given" );
		}
		if ( clients == null || clients.length == 0 ) {
			status.addError( "No clients given" );
		}
		if ( timeoutSeconds < 1 ) {
			status.addError( "Invalid timeout given" );
		}

		// A missing client list was already reported above; do not dereference it
		if ( clients != null ) {
			for ( Client client : clients ) {
				if ( client == null || Util.isEmpty( client.clientip ) ) {
					status.addError( "Client without host/address given!" );
					continue;
				}
				// Results are keyed by machine UUID if present, otherwise by address
				if ( !Util.isEmpty( client.machineuuid ) ) {
					status.result.put( client.machineuuid, new Result() );
				} else {
					status.result.put( client.clientip, new Result() );
				}
				client.timeoutLeft = timeoutSeconds * 1000;
			}
		}

		if ( status.error != null )
			return false;

		// NOTE: host keys of the clients are not verified
		JSch.setConfig( "StrictHostKeyChecking", "no" );
		sshClient = new JSch();
		try {
			sshClient.addIdentity( "", sshkey.getBytes( StandardCharsets.UTF_8 ), null, null );
		} catch ( JSchException e ) {
			status.addError( e.toString() );
			return false;
		}

		return true;
	}

	/**
	 * Runs the command on all clients using at most {@link #MAX_PARALLEL} worker threads and
	 * waits until they are done or the overall time bound has passed.
	 *
	 * @return false if the calling thread was interrupted while waiting, true otherwise
	 */
	@Override
	protected boolean execute()
	{
		ExecutorService pool = Executors.newFixedThreadPool( Math.min( clients.length, MAX_PARALLEL ) );
		for ( final Client client : clients ) {
			if ( client == null || Util.isEmpty( client.clientip ) )
				continue;
			pool.submit( new Runnable() {
				@Override
				public void run()
				{
					runClient( client );
				}
			} );
		}
		pool.shutdown();

		// Each worker processes at most ceil(n / MAX_PARALLEL) clients in sequence, and each
		// client is bounded by its timeout budget (plus connection slack), so this should be
		// an upper bound for the whole run. awaitTermination returns early once all are done.
		long rounds = ( clients.length + MAX_PARALLEL - 1 ) / MAX_PARALLEL;
		long maxWaitSeconds = rounds * ( (long)timeoutSeconds + PER_CLIENT_SLACK_SECONDS );
		try {
			if ( !pool.awaitTermination( maxWaitSeconds, TimeUnit.SECONDS ) ) {
				// Interrupt stragglers so their polling loops stop and sessions get closed
				pool.shutdownNow();
			}
		} catch ( InterruptedException e ) {
			pool.shutdownNow();
			Thread.currentThread().interrupt();
			return false;
		}

		return true;
	}

	/**
	 * Connects to one client, runs the command there and records the outcome in the client's
	 * {@link Result}. Exceptions are caught and stored in the result's stderr with state
	 * {@link State#ERROR}; channel and session are always disconnected afterwards.
	 */
	private void runClient( Client client )
	{
		String clientId = !Util.isEmpty( client.machineuuid ) ? client.machineuuid : client.clientip;
		Result clientResult = status.result.get( clientId );
		if ( clientResult == null ) {
			status.addError( "WHAT IN THE F*CK!? No clientResult for " + clientId );
			return;
		}
		Session session = null;
		ChannelExec channel = null;
		clientResult.state = State.CONNECTING;
		try {
			long start = System.currentTimeMillis();
			session = sshClient.getSession(
					Util.isEmpty( client.username ) ? "root" : client.username,
					client.clientip,
					( client.port > 0 && client.port < 65536 ) ? client.port : port
			);
			// Spend at most half the budget connecting, but never less than 1.1s
			session.connect( Math.max( client.timeoutLeft / 2, 1100 ) );
			clientResult.state = State.PRE_EXEC;
			channel = (ChannelExec)session.openChannel( "exec" );
			client.timeoutLeft -= System.currentTimeMillis() - start;
			execCommand( channel, client, clientResult );
		} catch ( Exception e ) {
			clientResult.stderr.append( e.toString() ).append( '\n' );
			clientResult.state = State.ERROR;
		} finally {
			if ( channel != null ) {
				try {
					channel.disconnect();
				} catch ( Exception ignored ) {
					// Best effort cleanup
				}
			}
			if ( session != null ) {
				try {
					session.disconnect();
				} catch ( Exception ignored ) {
					// Best effort cleanup
				}
			}
		}
	}

	/**
	 * Copies characters from {@code in} to {@code out} for as long as the reader reports
	 * {@code ready()}. If {@code out} and {@code other} together exceed
	 * {@link #MAX_OUTPUT_CHARS}, {@code out} is cut so the combined length equals the limit
	 * (or emptied if {@code other} alone already fills it).
	 *
	 * @return true if the output limit was hit, false otherwise
	 */
	private boolean copy( char[] cbuf, InputStreamReader in, StringBuffer out, StringBuffer other )
	{
		try {
			while ( in.ready() ) {
				int nb = in.read( cbuf );
				if ( nb == -1 )
					break;
				out.append( cbuf, 0, nb );
				if ( out.length() + other.length() > MAX_OUTPUT_CHARS ) {
					int keep = Math.max( 0, MAX_OUTPUT_CHARS - other.length() );
					if ( keep < out.length() ) {
						out.setLength( keep );
					}
					return true;
				}
			}
		} catch ( IOException e ) {
			e.printStackTrace();
		}
		return false;
	}

	/**
	 * Drains both output streams of the command into the given result.
	 *
	 * @return true if the combined output limit was reached
	 */
	private boolean drainOutput( char[] cbuf, InputStreamReader stdout, InputStreamReader stderr, Result result )
	{
		boolean truncated = copy( cbuf, stdout, result.stdout, result.stderr );
		if ( copy( cbuf, stderr, result.stderr, result.stdout ) ) {
			truncated = true;
		}
		return truncated;
	}

	/**
	 * Runs the command on an opened exec channel and polls its output until the command
	 * finishes, the client's remaining time budget is used up, or the output limit is reached.
	 * If a machine UUID is set, the remote side first checks /etc/system-uuid and exits with 42
	 * on mismatch. Sets {@link Result#state} to DONE or TIMEOUT and, if the channel was closed
	 * by the remote side, stores the exit status.
	 */
	private void execCommand( ChannelExec channel, Client client, Result result ) throws JSchException, IOException
	{
		long st = System.currentTimeMillis();
		String cmd = "export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/openslx/sbin:/opt/openslx/bin\n";
		if ( !Util.isEmpty( client.machineuuid ) ) {
			cmd += "if [ \"$(cat /etc/system-uuid)\" != \"" + client.machineuuid + "\" ]; then\n"
					+ "echo 'Machine UUID mismatch' >&2\n"
					+ "exit 42\n"
					+ "fi\n";
		}
		cmd += command;
		channel.setCommand( cmd );
		InputStreamReader stdout = new InputStreamReader( channel.getInputStream(), StandardCharsets.UTF_8 );
		InputStreamReader stderr = new InputStreamReader( channel.getErrStream(), StandardCharsets.UTF_8 );
		result.state = State.EXEC;
		channel.connect( Math.max( client.timeoutLeft - 1000, 500 ) );
		long now = System.currentTimeMillis();
		client.timeoutLeft -= now - st;
		// Poll as long as we have time, the command runs and the output fits
		char[] cbuf = new char[2000];
		boolean truncated = false;
		while ( !truncated && client.timeoutLeft > 0 && !channel.isClosed() ) {
			st = now;
			try {
				Thread.sleep( Math.min( 250, client.timeoutLeft ) );
			} catch ( InterruptedException e ) {
				Thread.currentThread().interrupt();
				break;
			} finally {
				now = System.currentTimeMillis();
				client.timeoutLeft -= now - st;
			}
			truncated = drainOutput( cbuf, stdout, stderr, result );
		}
		// Pick up anything that arrived after the last poll
		if ( !truncated ) {
			truncated = drainOutput( cbuf, stdout, stderr, result );
		}
		if ( truncated ) {
			status.addError( "Truncating output of client " + client.clientip );
		}
		// Decide completion before disconnect(): once we tear the channel down ourselves,
		// isClosed() may no longer tell whether the remote command actually ended
		boolean finished = channel.isClosed();
		result.state = finished ? State.DONE : State.TIMEOUT;
		try {
			channel.disconnect();
		} catch ( Exception ignored ) {
			// Best effort cleanup
		}
		if ( finished ) {
			result.exitCode = channel.getExitStatus();
		}
	}

	/**
	 * Output - contains additional status data of this task
	 */
	static class Output
	{
		/** Client id (machine UUID, or address if no UUID was given) -> per-client result */
		private Map<String, Result> result = new ConcurrentHashMap<>();

		/** Newline-separated error messages, null while no error occurred */
		private String error;

		/** Appends one line to {@link #error}; synchronized because workers call it concurrently. */
		private synchronized void addError( String e )
		{
			if ( error == null ) {
				error = e + "\n";
			} else {
				error += e + "\n";
			}
		}
	}

	/** One target machine as passed in the task parameters. */
	static class Client
	{
		/** Optional; if set, used as result key and verified against /etc/system-uuid */
		@Expose
		private String machineuuid;
		/** Host name or address to connect to (required) */
		@Expose
		private String clientip;
		/** Optional per-client SSH port; the task's port is used if not in 1..65535 */
		@Expose
		private int port;
		/** Optional login name; "root" is used if empty */
		@Expose
		private String username;
		/** How many ms of the given timeout are left */
		private int timeoutLeft;
	}

	/** Progress of a single client, in the order the states are normally reached. */
	static enum State
	{
		QUEUED,
		CONNECTING,
		PRE_EXEC,
		EXEC,
		DONE,
		ERROR,
		TIMEOUT
	}

	/** Outcome of running the command on one client. */
	static class Result
	{
		State state = State.QUEUED;
		StringBuffer stdout = new StringBuffer();
		StringBuffer stderr = new StringBuffer();
		/** Remote exit status; stays -1 unless the channel was closed by the remote side */
		int exitCode = -1;
	}

}