package org.openslx.taskmanager.tasks; import java.io.IOException; import java.io.OutputStream; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.CharsetDecoder; import java.nio.charset.CodingErrorAction; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.openslx.satserver.util.Util; import org.openslx.taskmanager.api.AbstractTask; import org.openslx.util.PrioThreadFactory; import com.google.gson.annotations.Expose; import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; public class RemoteExec extends AbstractTask { private static final Logger LOGGER = LogManager.getLogger( RemoteExec.class ); protected final static int MAX_OUTPUT_PER_CLIENT = 400000; @Expose private Client[] clients; @Expose private String sshkey; @Expose private int port; @Expose private String command; @Expose private int timeoutSeconds = 5; private JSch sshClient; private Output status = new Output(); public RemoteExec() { } public RemoteExec(Client[] clients, String sshkey, int port, String command, int timeoutSeconds) { this.clients = clients; this.sshkey = sshkey; this.port = port; this.command = command; this.timeoutSeconds = timeoutSeconds; initTask(); } @Override protected boolean initTask() { this.setStatusObject( this.status ); if ( Util.isEmpty( sshkey ) ) { status.addError( "No SSH key given" ); } if ( port < 1 || port > 65535 ) { status.addError( "Invalid port number" ); } if ( Util.isEmpty( command ) ) { status.addError( "No command given" ); } if ( clients == null || clients.length == 0 ) { status.addError( "No clients given" ); } if ( timeoutSeconds < 1 ) { status.addError( "Invalid timeout given" ); } for ( Client client : clients ) { if ( Util.isEmpty( client.clientip ) ) { status.addError( "Client without host/address given!" ); continue; } if ( !Util.isEmpty( client.machineuuid ) ) { status.result.put( client.machineuuid, new Result() ); } else { status.result.put( client.clientip, new Result() ); } client.timeoutLeft = timeoutSeconds * 1000; } if ( status.error != null ) return false; JSch.setConfig( "StrictHostKeyChecking", "no" ); sshClient = new JSch(); try { sshClient.addIdentity( "", sshkey.getBytes(), null, null ); } catch ( JSchException e ) { status.addError( e.toString() ); return false; } return true; } @Override protected boolean execute() { ExecutorService tp = Executors.newFixedThreadPool( clients.length > 4 ? 4 : clients.length, new PrioThreadFactory( "RemExec" ) ); for ( final Client client : clients ) { if ( Util.isEmpty( client.clientip ) ) continue; tp.submit( new Runnable() { public void run() { String clientId = !Util.isEmpty( client.machineuuid ) ? client.machineuuid : client.clientip; Result clientResult = status.result.get( clientId ); if ( clientResult == null ) { status.addError( "WHAT IN THE F*CK!? No clientResult for " + clientId ); return; } Session session = null; ChannelExec channel = null; clientResult.state = State.CONNECTING; try { long st = System.currentTimeMillis(); session = sshClient.getSession( Util.isEmpty( client.username ) ? "root" : client.username, client.clientip, ( client.port > 0 && client.port < 65536 ) ? client.port : port ); session.connect( Math.max( client.timeoutLeft / 2, 1100 ) ); clientResult.state = State.PRE_EXEC; channel = (ChannelExec)session.openChannel( "exec" ); client.timeoutLeft -= System.currentTimeMillis() - st; execCommand( channel, client, clientResult ); } catch ( Exception e ) { clientResult.stderr.append( e.toString() + "\n" ); clientResult.state = State.ERROR; } finally { if ( session != null ) { try { channel.disconnect(); } catch ( Exception e ) { } try { session.disconnect(); } catch ( Exception e ) { } } } } } ); } tp.shutdown(); try { tp.awaitTermination( clients.length * this.timeoutSeconds + 5, TimeUnit.SECONDS ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } return true; } public Output getStatusObject() { return this.status; } private void execCommand( ChannelExec channel, Client client, Result result ) throws JSchException, IOException { long st = System.currentTimeMillis(); String cmd = "export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/openslx/sbin:/opt/openslx/bin\n"; if ( !Util.isEmpty( client.machineuuid ) ) { cmd += "if [ \"$(cat /etc/system-uuid)\" != \"" + client.machineuuid + "\" ]; then\n" + "echo 'Machine UUID mismatch' >&2\n" + "exit 42\n" + "fi\n"; } cmd += command; channel.setCommand( cmd ); LimitedBufferStream lbsOut = new LimitedBufferStream( result.stdout ); LimitedBufferStream lbsErr = new LimitedBufferStream( result.stderr ); channel.setOutputStream( lbsOut ); channel.setErrStream( lbsErr ); result.state = State.EXEC; channel.connect( Math.max( client.timeoutLeft - 1000, 500 ) ); long now = System.currentTimeMillis(); client.timeoutLeft -= now - st; // Read as long as we got time while ( client.timeoutLeft > 0 && !channel.isClosed() ) { st = now; try { Thread.sleep( Math.min( 250, client.timeoutLeft ) ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); break; } catch ( Exception ee ) { LOGGER.warn( "Cannot sleep", ee ); break; } finally { now = System.currentTimeMillis(); client.timeoutLeft -= now - st; } } if ( channel.isClosed() ) { result.state = State.DONE; } else { result.state = State.TIMEOUT; } try { channel.disconnect(); } catch ( Exception ee ) { } if ( channel.isClosed() ) { result.exitCode = channel.getExitStatus(); } } static class LimitedBufferStream extends OutputStream { private final StringBuffer sb; private final CharBuffer cb = CharBuffer.allocate( 200 ); private final ByteBuffer bb = ByteBuffer.allocate( 200 ); private final CharsetDecoder decoder; public LimitedBufferStream(StringBuffer sb) { this.sb = sb; decoder = StandardCharsets.UTF_8.newDecoder(); decoder.onMalformedInput( CodingErrorAction.REPLACE ); decoder.onUnmappableCharacter( CodingErrorAction.REPLACE ); decoder.reset(); } @Override public void write( int b ) throws IOException { if ( sb.length() >= RemoteExec.MAX_OUTPUT_PER_CLIENT ) return; bb.put( (byte) ( b & 0xff ) ); decode(); } @Override public void write( byte[] b, int off, int len ) throws IOException { if ( sb.length() >= RemoteExec.MAX_OUTPUT_PER_CLIENT ) return; while ( len > 0 ) { int nlen = Math.min( bb.remaining(), len ); if ( nlen <= 0 ) throw new RuntimeException( "Empty buffer" ); bb.put( b, off, nlen ); off += nlen; len -= nlen; decode(); } } private void decode() { if ( bb.position() == 0 ) return; ( (Buffer)bb ).limit( bb.position() ); ( (Buffer)bb ).position( 0 ); try { decoder.decode( bb, cb, false ); } catch ( Throwable t ) { LOGGER.warn( "Cannot convert data to UTF8", t ); } bb.compact(); ( (Buffer)cb ).limit( cb.position() ); ( (Buffer)cb ).position( 0 ); sb.append( cb.toString() ); ( (Buffer)cb ).clear(); } } /** * Output - contains additional status data of this task */ static class Output { /** UUID -> Output */ Map result = new ConcurrentHashMap<>(); String error; private synchronized void addError( String e ) { if ( error == null ) { error = e + "\n"; } else { error += e + "\n"; } } } static class Client { @Expose private String machineuuid; @Expose private String clientip; @Expose private int port; @Expose private String username; /** How many ms of the given timeout are left */ private int timeoutLeft; public Client() { } public Client(String machineuuid, String clientip, int port, String username) { this.machineuuid = machineuuid; this.clientip = clientip; this.port = port; this.username = username; } } static enum State { QUEUED, CONNECTING, PRE_EXEC, EXEC, DONE, ERROR, TIMEOUT } static class Result { State state = State.QUEUED; StringBuffer stdout = new StringBuffer(); StringBuffer stderr = new StringBuffer(); int exitCode = -1; } }