package org.openslx.taskmanager.tasks; import java.io.IOException; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import com.sshtools.common.publickey.InvalidPassphraseException; import com.sshtools.common.publickey.SshKeyUtils; import com.sshtools.common.publickey.bc.OpenSSHPrivateKeyFileBC; import com.sshtools.common.ssh.components.SshKeyPair; import org.openslx.satserver.util.Util; import org.openslx.taskmanager.api.AbstractTask; import com.google.gson.annotations.Expose; import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; public class RemoteExec extends AbstractTask { @Expose private Client[] clients; @Expose private String sshkey; @Expose private int port; @Expose private String command; @Expose private int timeoutSeconds = 5; private JSch sshClient; private Output status = new Output(); @Override protected boolean initTask() { this.setStatusObject( this.status ); if ( Util.isEmpty( sshkey ) ) { status.addError( "No SSH key given" ); } if ( port < 1 || port > 65535 ) { status.addError( "Invalid port number" ); } if ( Util.isEmpty( command ) ) { status.addError( "No command given" ); } if ( clients == null || clients.length == 0 ) { status.addError( "No clients given" ); } if ( timeoutSeconds < 1 ) { status.addError( "Invalid timeout given" ); } for ( Client client : clients ) { if ( Util.isEmpty( client.clientip ) ) { status.addError( "Client without host/address given!" ); continue; } if ( !Util.isEmpty( client.machineuuid ) ) { status.result.put( client.machineuuid, new Result() ); } else { status.result.put( client.clientip, new Result() ); } client.timeoutLeft = timeoutSeconds * 1000; } if ( status.error != null ) return false; // Convert ssh key to old rsa format supported by JSch try { SshKeyPair key = SshKeyUtils.getPrivateKey(sshkey, ""); sshkey = new String(new OpenSSHPrivateKeyFileBC(key, "").getFormattedKey(), StandardCharsets.UTF_8); } catch (IOException | InvalidPassphraseException e) { e.printStackTrace(); } JSch.setConfig( "StrictHostKeyChecking", "no" ); sshClient = new JSch(); try { sshClient.addIdentity( "", sshkey.getBytes(), null, null ); } catch ( JSchException e ) { status.addError( e.toString() ); return false; } return true; } @Override protected boolean execute() { ExecutorService tp = Executors.newFixedThreadPool( clients.length > 4 ? 4 : clients.length ); for ( final Client client : clients ) { if ( Util.isEmpty( client.clientip ) ) continue; tp.submit( new Runnable() { public void run() { String clientId = !Util.isEmpty( client.machineuuid ) ? client.machineuuid : client.clientip; Result clientResult = status.result.get( clientId ); if ( clientResult == null ) { status.addError( "WHAT IN THE F*CK!? No clientResult for " + clientId ); return; } Session session = null; ChannelExec channel = null; clientResult.state = State.CONNECTING; try { long st = System.currentTimeMillis(); session = sshClient.getSession( Util.isEmpty( client.username ) ? "root" : client.username, client.clientip, ( client.port > 0 && client.port < 65536 ) ? client.port : port ); session.connect( Math.max( client.timeoutLeft / 2, 1100 ) ); clientResult.state = State.PRE_EXEC; channel = (ChannelExec)session.openChannel( "exec" ); client.timeoutLeft -= System.currentTimeMillis() - st; execCommand( channel, client, clientResult ); } catch ( Exception e ) { clientResult.stderr.append( e.toString() + "\n" ); clientResult.state = State.ERROR; } finally { if ( session != null ) { try { channel.disconnect(); } catch ( Exception e ) { } try { session.disconnect(); } catch ( Exception e ) { } } } } } ); } tp.shutdown(); try { tp.awaitTermination( clients.length * 5, TimeUnit.SECONDS ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } return true; } private void copy( char[] cbuf, InputStreamReader in, StringBuffer out, StringBuffer other ) { try { while ( in.ready() ) { int nb = in.read( cbuf ); if ( nb == -1 ) break; out.append( cbuf, 0, nb ); if (out.length() + other.length() > 40000 ) { int trunc = 40000 - other.length(); if ( trunc > 0 && trunc < out.length() ) { out.setLength( trunc ); } break; } } } catch ( IOException e ) { e.printStackTrace(); } } private void execCommand( ChannelExec channel, Client client, Result result ) throws JSchException, IOException { long st = System.currentTimeMillis(); String cmd = "export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/openslx/sbin:/opt/openslx/bin\n"; if ( !Util.isEmpty( client.machineuuid ) ) { cmd += "if [ \"$(cat /etc/system-uuid)\" != \"" + client.machineuuid + "\" ]; then\n" + "echo 'Machine UUID mismatch' >&2\n" + "exit 42\n" + "fi\n"; } cmd += command; channel.setCommand( cmd ); InputStreamReader stdout = new InputStreamReader( channel.getInputStream(), StandardCharsets.UTF_8 ); InputStreamReader stderr = new InputStreamReader( channel.getErrStream(), StandardCharsets.UTF_8 ); result.state = State.EXEC; channel.connect( Math.max( client.timeoutLeft - 1000, 500 ) ); long now = System.currentTimeMillis(); client.timeoutLeft -= now - st; // Read as long as we got time char[] cbuf = new char[2000]; while ( client.timeoutLeft > 0 && !channel.isClosed() ) { st = now; try { Thread.sleep( Math.min( 250, client.timeoutLeft ) ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); break; } catch ( Exception ee ) { break; } finally { now = System.currentTimeMillis(); client.timeoutLeft -= now - st; } copy( cbuf, stdout, result.stdout, result.stderr ); copy( cbuf, stderr, result.stderr, result.stdout ); // Check for reasonable output size if ( result.stdout.length() + result.stderr.length() > 40000 ) { status.addError( "Truncating output of client " + client.clientip ); break; } } copy( cbuf, stdout, result.stdout, result.stderr ); copy( cbuf, stderr, result.stderr, result.stdout ); if ( channel.isClosed() ) { result.state = State.DONE; } else { result.state = State.TIMEOUT; } try { channel.disconnect(); } catch ( Exception ee ) { } if ( channel.isClosed() ) { result.exitCode = channel.getExitStatus(); } } /** * Output - contains additional status data of this task */ static class Output { /** UUID -> Output */ private Map result = new ConcurrentHashMap<>(); private String error; private synchronized void addError( String e ) { if ( error == null ) { error = e + "\n"; } else { error += e + "\n"; } } } static class Client { @Expose private String machineuuid; @Expose private String clientip; @Expose private int port; @Expose private String username; /** How many ms of the given timeout are left */ private int timeoutLeft; } static enum State { QUEUED, CONNECTING, PRE_EXEC, EXEC, DONE, ERROR, TIMEOUT } static class Result { State state = State.QUEUED; StringBuffer stdout = new StringBuffer(); StringBuffer stderr = new StringBuffer(); int exitCode = -1; } }