summaryrefslogtreecommitdiffstats
path: root/core/modules/qemu/runvirt-plugin-qemu/src/main/java/org/openslx/runvirt/plugin/qemu/cmdln/CommandLineArgs.java
diff options
context:
space:
mode:
Diffstat (limited to 'core/modules/qemu/runvirt-plugin-qemu/src/main/java/org/openslx/runvirt/plugin/qemu/cmdln/CommandLineArgs.java')
-rw-r--r--core/modules/qemu/runvirt-plugin-qemu/src/main/java/org/openslx/runvirt/plugin/qemu/cmdln/CommandLineArgs.java30
1 files changed, 26 insertions, 4 deletions
diff --git a/core/modules/qemu/runvirt-plugin-qemu/src/main/java/org/openslx/runvirt/plugin/qemu/cmdln/CommandLineArgs.java b/core/modules/qemu/runvirt-plugin-qemu/src/main/java/org/openslx/runvirt/plugin/qemu/cmdln/CommandLineArgs.java
index 21b11968..396c0d8c 100644
--- a/core/modules/qemu/runvirt-plugin-qemu/src/main/java/org/openslx/runvirt/plugin/qemu/cmdln/CommandLineArgs.java
+++ b/core/modules/qemu/runvirt-plugin-qemu/src/main/java/org/openslx/runvirt/plugin/qemu/cmdln/CommandLineArgs.java
@@ -13,6 +13,8 @@ import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
+import org.openslx.libvirt.domain.device.HostdevPciDeviceDescription;
+import org.openslx.runvirt.plugin.qemu.configuration.TransformationSpecificQemuPciPassthrough;
import org.openslx.util.Util;
/**
@@ -453,13 +455,33 @@ public class CommandLineArgs
}
/**
- * Returns the state whether a passthrough of a NVIDIA GPU is required.
- *
- * @return state whether a passthrough of a NVIDIA GPU is required.
+ * Returns the state whether a passthrough of a NVIDIA GPU is requested.
+ * Do this by checking the vendor ID of each PCI device that's being passed
+ * through. If one of them is nvidia, assume we're running passthrough for
+ * an nvidia GPU.
*/
public boolean isNvidiaGpuPassthroughEnabled()
{
- return this.getVmNvGpuIds0().size() > 0;
+ List<String> pciIds = this.getVmNvGpuIds0();
+ // parse PCI device description and PCI device address
+ for ( int i = 0; i < pciIds.size() - 1; i += 2 ) {
+ // parse vendor and device ID
+ HostdevPciDeviceDescription deviceDescription = null;
+ try {
+ deviceDescription = HostdevPciDeviceDescription.valueOf( pciIds.get( i ) );
+ } catch ( IllegalArgumentException e ) {
+ continue;
+ }
+
+ // validate vendor ID
+ final int vendorId = deviceDescription.getVendorId();
+ if ( TransformationSpecificQemuPciPassthrough.NVIDIA_PCI_VENDOR_ID != vendorId )
+ continue;
+
+ // Have at least one device by nvidia, just assume it's a GPU for now
+ return true;
+ }
+ return false;
}
/**