summaryrefslogblamecommitdiffstats
path: root/src/main/java/org/openslx/taskmanager/tasks/RemoteExec.java
blob: daec3297bf76974b18a79f047d9ec616c009fda8 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11



                                      






                                          






                                              

                                           










                                                
        
                                                                                      

                                                                  
























































































































                                                                                                                                     
                                                                                                          






                                                           












                                                                                                                                             



                                                                                      




                                                                              







                                                                                    
                                                                  




                                                                 
                 













                                                                  





























































                                                                                     





















































                                                                               
package org.openslx.taskmanager.tasks;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CoderResult;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openslx.satserver.util.Util;
import org.openslx.taskmanager.api.AbstractTask;

import com.google.gson.annotations.Expose;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;

/**
 * Task that runs a shell command on one or more remote clients via SSH (JSch)
 * and collects, per client, stdout, stderr, the exit code and a progress state.
 * <p>
 * Clients are processed in parallel by a small fixed thread pool. Each client
 * gets its own time budget of {@code timeoutSeconds}, which is consumed by
 * connecting, opening the exec channel and waiting for the command to finish.
 */
public class RemoteExec extends AbstractTask
{
	
	private static final Logger LOGGER = LogManager.getLogger( RemoteExec.class );
	
	/** Approximate cap on the number of characters kept per output stream (stdout/stderr) of a client. */
	protected final static int MAX_OUTPUT_PER_CLIENT = 400000;

	/** Upper bound for the number of clients handled concurrently. */
	private static final int MAX_THREADS = 4;

	/** Clients to run the command on. */
	@Expose
	private Client[] clients;

	/** Private key passed to JSch as identity (no passphrase). */
	@Expose
	private String sshkey;

	/** SSH port used for every client that does not specify a valid port of its own. */
	@Expose
	private int port;

	/** Shell command executed on each client (after PATH is exported). */
	@Expose
	private String command;

	/** Time budget per client, in seconds, covering connect and command execution. */
	@Expose
	private int timeoutSeconds = 5;

	private JSch sshClient;

	private Output status = new Output();

	/**
	 * Validates the task parameters, creates one {@link Result} entry per valid
	 * client and prepares the JSch instance with the given key.
	 *
	 * @return {@code true} if the task can be executed, {@code false} if any
	 *         validation error was recorded in the status object
	 */
	@Override
	protected boolean initTask()
	{
		this.setStatusObject( this.status );

		if ( Util.isEmpty( sshkey ) ) {
			status.addError( "No SSH key given" );
		}
		if ( port < 1 || port > 65535 ) {
			status.addError( "Invalid port number" );
		}
		if ( Util.isEmpty( command ) ) {
			status.addError( "No command given" );
		}
		if ( clients == null || clients.length == 0 ) {
			status.addError( "No clients given" );
		}
		if ( timeoutSeconds < 1 ) {
			status.addError( "Invalid timeout given" );
		}

		// Guard: the error above is only recorded, we must not iterate a null array
		if ( clients != null ) {
			// Compute in long and clamp so huge timeouts cannot overflow into a negative budget
			int timeoutMs = (int)Math.min( timeoutSeconds * 1000L, Integer.MAX_VALUE );
			for ( Client client : clients ) {
				if ( client == null || Util.isEmpty( client.clientip ) ) {
					status.addError( "Client without host/address given!" );
					continue;
				}
				status.result.put( client.getId(), new Result() );
				client.timeoutLeft = timeoutMs;
			}
		}

		if ( status.error != null )
			return false;

		// NOTE(review): JSch.setConfig is global (static) and disables host key verification
		JSch.setConfig( "StrictHostKeyChecking", "no" );
		sshClient = new JSch();
		try {
			sshClient.addIdentity( "", sshkey.getBytes( StandardCharsets.UTF_8 ), null, null );
		} catch ( JSchException e ) {
			status.addError( e.toString() );
			return false;
		}

		return true;
	}

	/**
	 * Runs the command on all clients using up to {@link #MAX_THREADS} worker
	 * threads and waits for them to finish. If the workers do not finish in
	 * time, or this thread is interrupted, outstanding work is cancelled.
	 *
	 * @return {@code false} if the waiting thread was interrupted, {@code true} otherwise
	 */
	@Override
	protected boolean execute()
	{
		ExecutorService tp = Executors.newFixedThreadPool( Math.min( clients.length, MAX_THREADS ) );
		for ( final Client client : clients ) {
			if ( client == null || Util.isEmpty( client.clientip ) )
				continue;
			tp.submit( new Runnable() {
				@Override
				public void run()
				{
					runClient( client );
				}
			} );
		}
		tp.shutdown();

		// Generous upper bound (every client using its full budget sequentially, plus slack);
		// computed in long to avoid int overflow
		long maxWaitSeconds = (long)clients.length * this.timeoutSeconds + 5;
		try {
			if ( !tp.awaitTermination( maxWaitSeconds, TimeUnit.SECONDS ) ) {
				LOGGER.warn( "Remote exec did not finish within {} seconds, cancelling remaining work", maxWaitSeconds );
				// Interrupt stuck workers instead of leaving them running after the task returned
				tp.shutdownNow();
			}
		} catch ( InterruptedException e ) {
			tp.shutdownNow();
			Thread.currentThread().interrupt();
			return false;
		}

		return true;
	}

	/**
	 * Connects to a single client, runs the command and records the outcome in
	 * the client's {@link Result}. Exceptions are not propagated; they are
	 * appended to the client's stderr and the state is set to {@link State#ERROR}.
	 *
	 * @param client client to handle; must have a non-empty address
	 */
	private void runClient( Client client )
	{
		String clientId = client.getId();
		Result clientResult = status.result.get( clientId );
		if ( clientResult == null ) {
			status.addError( "WHAT IN THE F*CK!? No clientResult for " + clientId );
			return;
		}
		Session session = null;
		ChannelExec channel = null;
		clientResult.state = State.CONNECTING;
		try {
			long st = System.currentTimeMillis();
			session = sshClient.getSession(
					Util.isEmpty( client.username ) ? "root" : client.username,
					client.clientip,
					( client.port > 0 && client.port < 65536 ) ? client.port : port
			);
			// Spend at most half of the budget (but at least 1.1s) on the SSH handshake
			session.connect( Math.max( client.timeoutLeft / 2, 1100 ) );
			clientResult.state = State.PRE_EXEC;
			channel = (ChannelExec)session.openChannel( "exec" );
			client.timeoutLeft -= System.currentTimeMillis() - st;
			execCommand( channel, client, clientResult );
		} catch ( Exception e ) {
			clientResult.stderr.append( e.toString() + "\n" );
			clientResult.state = State.ERROR;
		} finally {
			if ( channel != null ) {
				try {
					channel.disconnect();
				} catch ( Exception ignored ) {
					// Best effort cleanup
				}
			}
			if ( session != null ) {
				try {
					session.disconnect();
				} catch ( Exception ignored ) {
					// Best effort cleanup
				}
			}
		}
	}

	/**
	 * Runs the configured command on an already opened exec channel and waits
	 * until it finishes or the client's remaining time budget is used up.
	 * Output is captured into {@code result.stdout}/{@code result.stderr}.
	 *
	 * @param channel exec channel of a connected session
	 * @param client client being handled; its {@code timeoutLeft} is decremented
	 * @param result result object to update with state, output and exit code
	 * @throws JSchException if the channel cannot be connected
	 * @throws IOException declared for callers; not thrown directly here
	 */
	private void execCommand( ChannelExec channel, Client client, Result result ) throws JSchException, IOException
	{
		long st = System.currentTimeMillis();
		StringBuilder cmd = new StringBuilder(
				"export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/openslx/sbin:/opt/openslx/bin\n" );
		if ( !Util.isEmpty( client.machineuuid ) ) {
			// Abort with exit code 42 if the host at this address is not the expected machine.
			// The UUID is single-quoted so shell metacharacters in it are not interpreted.
			cmd.append( "if [ \"$(cat /etc/system-uuid)\" != " )
					.append( shellQuote( client.machineuuid ) )
					.append( " ]; then\n" )
					.append( "echo 'Machine UUID mismatch' >&2\n" )
					.append( "exit 42\n" )
					.append( "fi\n" );
		}
		cmd.append( command );
		channel.setCommand( cmd.toString() );
		channel.setOutputStream( new LimitedBufferStream( result.stdout ) );
		channel.setErrStream( new LimitedBufferStream( result.stderr ) );
		result.state = State.EXEC;
		channel.connect( Math.max( client.timeoutLeft - 1000, 500 ) );
		long now = System.currentTimeMillis();
		client.timeoutLeft -= now - st;
		// Poll until the remote side closes the channel or the time budget is used up
		while ( client.timeoutLeft > 0 && !channel.isClosed() ) {
			st = now;
			try {
				Thread.sleep( Math.min( 250, client.timeoutLeft ) );
			} catch ( InterruptedException e ) {
				Thread.currentThread().interrupt();
				break;
			} catch ( Exception ee ) {
				LOGGER.warn( "Cannot sleep", ee );
				break;
			} finally {
				now = System.currentTimeMillis();
				client.timeoutLeft -= now - st;
			}
		}
		// Capture completion before disconnecting: disconnect() may itself mark the channel closed,
		// which would make a later isClosed() check meaningless
		boolean finished = channel.isClosed();
		result.state = finished ? State.DONE : State.TIMEOUT;
		if ( finished ) {
			result.exitCode = channel.getExitStatus();
		}
		try {
			channel.disconnect();
		} catch ( Exception ignored ) {
			// Best effort; the caller disconnects again anyway
		}
	}

	/**
	 * Quotes a value as a single POSIX shell word using single quotes.
	 *
	 * @param value raw value, must not be {@code null}
	 * @return the quoted value, safe to splice into a shell script
	 */
	private static String shellQuote( String value )
	{
		return "'" + value.replace( "'", "'\\''" ) + "'";
	}

	/**
	 * OutputStream that decodes incoming bytes as UTF-8 (malformed input is
	 * replaced) and appends the text to a StringBuffer, discarding further
	 * output once the buffer reaches roughly {@link RemoteExec#MAX_OUTPUT_PER_CLIENT}
	 * characters.
	 */
	static class LimitedBufferStream extends OutputStream
	{

		private final StringBuffer sb;
		private final CharBuffer cb = CharBuffer.allocate( 200 );
		private final ByteBuffer bb = ByteBuffer.allocate( 200 );
		private final CharsetDecoder decoder;

		public LimitedBufferStream(StringBuffer sb)
		{
			this.sb = sb;
			decoder = StandardCharsets.UTF_8.newDecoder();
			decoder.onMalformedInput( CodingErrorAction.REPLACE );
			decoder.onUnmappableCharacter( CodingErrorAction.REPLACE );
			decoder.reset();
		}

		@Override
		public void write( int b ) throws IOException
		{
			if ( sb.length() >= RemoteExec.MAX_OUTPUT_PER_CLIENT )
				return;
			bb.put( (byte) ( b & 0xff ) );
			decode();
		}

		@Override
		public void write( byte[] b, int off, int len ) throws IOException
		{
			// Re-check the limit for every chunk so a single large write cannot overshoot it by much
			while ( len > 0 && sb.length() < RemoteExec.MAX_OUTPUT_PER_CLIENT ) {
				int nlen = Math.min( bb.remaining(), len );
				if ( nlen <= 0 )
					throw new RuntimeException( "Empty buffer" );
				bb.put( b, off, nlen );
				off += nlen;
				len -= nlen;
				decode();
			}
		}

		/**
		 * Decodes as many buffered bytes as possible and appends the result to
		 * {@link #sb}. Bytes of an incomplete trailing UTF-8 sequence are kept
		 * in {@link #bb} for the next call.
		 */
		private void decode()
		{
			if ( bb.position() == 0 )
				return;
			// Flip the byte buffer for reading (casts keep Java 8 compatibility of Buffer methods)
			( (Buffer)bb ).limit( bb.position() );
			( (Buffer)bb ).position( 0 );
			try {
				decoder.decode( bb, cb, false );
			} catch ( Throwable t ) {
				LOGGER.warn( "Cannot convert data to UTF8", t );
			}
			// Keep undecoded bytes at the start of the buffer
			bb.compact();
			( (Buffer)cb ).limit( cb.position() );
			( (Buffer)cb ).position( 0 );
			sb.append( cb.toString() );
			( (Buffer)cb ).clear();
		}

	}

	/**
	 * Output - contains additional status data of this task
	 */
	static class Output
	{
		/** Client id (machine UUID, or IP address if no UUID was given) -> Result */
		private Map<String, Result> result = new ConcurrentHashMap<>();

		/** Newline-separated task-level error messages, or null if there were none. */
		private String error;

		private synchronized void addError( String e )
		{
			if ( error == null ) {
				error = e + "\n";
			} else {
				error += e + "\n";
			}
		}
	}

	/** A single target host as supplied in the task parameters. */
	static class Client
	{
		/** If set, used as result key and verified against /etc/system-uuid on the host. */
		@Expose
		private String machineuuid;
		/** Host name or address to connect to. */
		@Expose
		private String clientip;
		/** Per-client SSH port; the task's port is used unless this is in 1..65535. */
		@Expose
		private int port;
		/** SSH user name; "root" is used if empty. */
		@Expose
		private String username;
		/** How many ms of the given timeout are left */
		private int timeoutLeft;

		/**
		 * @return the key of this client in the result map: the machine UUID if
		 *         set, otherwise the client address
		 */
		String getId()
		{
			return !Util.isEmpty( machineuuid ) ? machineuuid : clientip;
		}
	}
	
	/** Progress state of a single client. */
	static enum State
	{
		QUEUED,
		CONNECTING,
		PRE_EXEC,
		EXEC,
		DONE,
		ERROR,
		TIMEOUT
	}

	/** Per-client outcome: state, captured output and exit code (-1 if unknown). */
	static class Result
	{
		State state = State.QUEUED;
		StringBuffer stdout = new StringBuffer();
		StringBuffer stderr = new StringBuffer();
		int exitCode = -1;
	}

}