summaryrefslogtreecommitdiffstats
path: root/src/VBox/ValidationKit/tests/usb
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-11 08:17:27 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-11 08:17:27 +0000
commitf215e02bf85f68d3a6106c2a1f4f7f063f819064 (patch)
tree6bb5b92c046312c4e95ac2620b10ddf482d3fa8b /src/VBox/ValidationKit/tests/usb
parentInitial commit. (diff)
downloadvirtualbox-f215e02bf85f68d3a6106c2a1f4f7f063f819064.tar.xz
virtualbox-f215e02bf85f68d3a6106c2a1f4f7f063f819064.zip
Adding upstream version 7.0.14-dfsg.upstream/7.0.14-dfsg
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--src/VBox/ValidationKit/tests/usb/Makefile.kmk53
-rwxr-xr-xsrc/VBox/ValidationKit/tests/usb/tdUsb1.py590
-rwxr-xr-xsrc/VBox/ValidationKit/tests/usb/tst-utsgadget.py154
-rwxr-xr-xsrc/VBox/ValidationKit/tests/usb/usbgadget.py1478
4 files changed, 2275 insertions, 0 deletions
diff --git a/src/VBox/ValidationKit/tests/usb/Makefile.kmk b/src/VBox/ValidationKit/tests/usb/Makefile.kmk
new file mode 100644
index 00000000..3a4741cc
--- /dev/null
+++ b/src/VBox/ValidationKit/tests/usb/Makefile.kmk
@@ -0,0 +1,53 @@
+# $Id: Makefile.kmk $
+## @file
+# VirtualBox Validation Kit - USB.
+#
+
+#
+# Copyright (C) 2014-2023 Oracle and/or its affiliates.
+#
+# This file is part of VirtualBox base platform packages, as
+# available from https://www.virtualbox.org.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation, in version 3 of the
+# License.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, see <https://www.gnu.org/licenses>.
+#
+# The contents of this file may alternatively be used under the terms
+# of the Common Development and Distribution License Version 1.0
+# (CDDL), a copy of it is provided in the "COPYING.CDDL" file included
+# in the VirtualBox distribution, in which case the provisions of the
+# CDDL are applicable instead of those of the GPL.
+#
+# You may elect to license modified versions of this file under the
+# terms and conditions of either the GPL or the CDDL or both.
+#
+# SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0
+#
+
+SUB_DEPTH = ../../../../..
+include $(KBUILD_PATH)/subheader.kmk
+
+
+INSTALLS += ValidationKitTestsUsb
+ValidationKitTestsUsb_TEMPLATE = VBoxValidationKitR3
+ValidationKitTestsUsb_INST = $(INST_VALIDATIONKIT)tests/usb/
+ValidationKitTestsUsb_EXEC_SOURCES := \
+ $(PATH_SUB_CURRENT)/tdUsb1.py \
+ $(PATH_SUB_CURRENT)/usbgadget.py \
+ $(PATH_SUB_CURRENT)/tst-utsgadget.py
+
+VBOX_VALIDATIONKIT_PYTHON_SOURCES += $(ValidationKitTestsUsb_EXEC_SOURCES)
+
+$(evalcall def_vbox_validationkit_process_python_sources)
+include $(FILE_KBUILD_SUB_FOOTER)
+
diff --git a/src/VBox/ValidationKit/tests/usb/tdUsb1.py b/src/VBox/ValidationKit/tests/usb/tdUsb1.py
new file mode 100755
index 00000000..9dd20ebf
--- /dev/null
+++ b/src/VBox/ValidationKit/tests/usb/tdUsb1.py
@@ -0,0 +1,590 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# $Id: tdUsb1.py $
+
+"""
+VirtualBox Validation Kit - USB testcase and benchmark.
+"""
+
+__copyright__ = \
+"""
+Copyright (C) 2014-2023 Oracle and/or its affiliates.
+
+This file is part of VirtualBox base platform packages, as
+available from https://www.virtualbox.org.
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation, in version 3 of the
+License.
+
+This program is distributed in the hope that it will be useful, but
+WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, see <https://www.gnu.org/licenses>.
+
+The contents of this file may alternatively be used under the terms
+of the Common Development and Distribution License Version 1.0
+(CDDL), a copy of it is provided in the "COPYING.CDDL" file included
+in the VirtualBox distribution, in which case the provisions of the
+CDDL are applicable instead of those of the GPL.
+
+You may elect to license modified versions of this file under the
+terms and conditions of either the GPL or the CDDL or both.
+
+SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0
+"""
+__version__ = "$Revision: 155244 $"
+
+
+# Standard Python imports.
+import os;
+import sys;
+import socket;
+
+# Only the main script needs to modify the path.
+try: __file__
+except: __file__ = sys.argv[0];
+g_ksValidationKitDir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))));
+sys.path.append(g_ksValidationKitDir);
+
+# Validation Kit imports.
+from testdriver import reporter;
+from testdriver import base;
+from testdriver import vbox;
+from testdriver import vboxcon;
+
+# USB gadget control import
+import usbgadget;
+
+# Python 3 hacks:
+if sys.version_info[0] >= 3:
+ xrange = range; # pylint: disable=redefined-builtin,invalid-name
+
+
+class tdUsbBenchmark(vbox.TestDriver): # pylint: disable=too-many-instance-attributes
+ """
+ USB benchmark.
+ """
+
+ # The available test devices
+ #
+ # The first key is the hostname of the host the test is running on.
+ # It contains a new dictionary with the attached gadgets based on the
+ # USB speed we want to test (Low, Full, High, Super).
+ # The parameters consist of the hostname of the gadget in the network
+ # and the hardware type.
+ kdGadgetParams = {
+ 'adaris': {
+ 'Low': ('usbtest.de.oracle.com', None),
+ 'Full': ('usbtest.de.oracle.com', None),
+ 'High': ('usbtest.de.oracle.com', None),
+ 'Super': ('usbtest.de.oracle.com', None)
+ },
+ };
+
+ # Mappings of USB controllers to supported USB device speeds.
+ kdUsbSpeedMappings = {
+ 'OHCI': ['Low', 'Full'],
+ 'EHCI': ['High'],
+ 'XHCI': ['Low', 'Full', 'High', 'Super']
+ };
+
+ # Tests currently disabled because they fail, need investigation.
+ kdUsbTestsDisabled = {
+ 'Low': [24],
+ 'Full': [24],
+ 'High': [24],
+ 'Super': [24]
+ };
+
+ def __init__(self):
+ vbox.TestDriver.__init__(self);
+ self.asRsrcs = None;
+ self.asTestVMsDef = ['tst-arch'];
+ self.asTestVMs = self.asTestVMsDef;
+ self.asSkipVMs = [];
+ self.asVirtModesDef = ['hwvirt', 'hwvirt-np', 'raw'];
+ self.asVirtModes = self.asVirtModesDef;
+ self.acCpusDef = [1, 2,];
+ self.acCpus = self.acCpusDef;
+ self.asUsbCtrlsDef = ['OHCI', 'EHCI', 'XHCI'];
+ self.asUsbCtrls = self.asUsbCtrlsDef;
+ self.asUsbSpeedDef = ['Low', 'Full', 'High', 'Super'];
+ self.asUsbSpeed = self.asUsbSpeedDef;
+ self.asUsbTestsDef = ['Compliance', 'Reattach'];
+ self.asUsbTests = self.asUsbTestsDef;
+ self.cUsbReattachCyclesDef = 100;
+ self.cUsbReattachCycles = self.cUsbReattachCyclesDef;
+ self.sHostname = socket.gethostname().lower();
+ self.sGadgetHostnameDef = 'usbtest.de.oracle.com';
+ self.uGadgetPortDef = None;
+ self.sUsbCapturePathDef = self.sScratchPath;
+ self.sUsbCapturePath = self.sUsbCapturePathDef;
+ self.fUsbCapture = False;
+
+ #
+ # Overridden methods.
+ #
+ def showUsage(self):
+ rc = vbox.TestDriver.showUsage(self);
+ reporter.log('');
+ reporter.log('tdUsb1 Options:');
+ reporter.log(' --virt-modes <m1[:m2[:]]');
+ reporter.log(' Default: %s' % (':'.join(self.asVirtModesDef)));
+ reporter.log(' --cpu-counts <c1[:c2[:]]');
+ reporter.log(' Default: %s' % (':'.join(str(c) for c in self.acCpusDef)));
+ reporter.log(' --test-vms <vm1[:vm2[:...]]>');
+ reporter.log(' Test the specified VMs in the given order. Use this to change');
+ reporter.log(' the execution order or limit the choice of VMs');
+ reporter.log(' Default: %s (all)' % (':'.join(self.asTestVMsDef)));
+ reporter.log(' --skip-vms <vm1[:vm2[:...]]>');
+ reporter.log(' Skip the specified VMs when testing.');
+ reporter.log(' --usb-ctrls <u1[:u2[:]]');
+ reporter.log(' Default: %s' % (':'.join(str(c) for c in self.asUsbCtrlsDef)));
+ reporter.log(' --usb-speed <s1[:s2[:]]');
+ reporter.log(' Default: %s' % (':'.join(str(c) for c in self.asUsbSpeedDef)));
+ reporter.log(' --usb-tests <s1[:s2[:]]');
+ reporter.log(' Default: %s' % (':'.join(str(c) for c in self.asUsbTestsDef)));
+ reporter.log(' --usb-reattach-cycles <cycles>');
+ reporter.log(' Default: %s' % (self.cUsbReattachCyclesDef));
+ reporter.log(' --hostname: <hostname>');
+ reporter.log(' Default: %s' % (self.sHostname));
+ reporter.log(' --default-gadget-host <hostname>');
+ reporter.log(' Default: %s' % (self.sGadgetHostnameDef));
+ reporter.log(' --default-gadget-port <port>');
+ reporter.log(' Default: %s' % (6042));
+ reporter.log(' --usb-capture-path <path>');
+ reporter.log(' Default: %s' % (self.sUsbCapturePathDef));
+ reporter.log(' --usb-capture');
+ reporter.log(' Whether to capture the USB traffic for each test');
+ return rc;
+
+ def parseOption(self, asArgs, iArg): # pylint: disable=too-many-branches,too-many-statements
+ if asArgs[iArg] == '--virt-modes':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--virt-modes" takes a colon separated list of modes');
+ self.asVirtModes = asArgs[iArg].split(':');
+ for s in self.asVirtModes:
+ if s not in self.asVirtModesDef:
+ raise base.InvalidOption('The "--virt-modes" value "%s" is not valid; valid values are: %s' \
+ % (s, ' '.join(self.asVirtModesDef)));
+ elif asArgs[iArg] == '--cpu-counts':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--cpu-counts" takes a colon separated list of cpu counts');
+ self.acCpus = [];
+ for s in asArgs[iArg].split(':'):
+ try: c = int(s);
+ except: raise base.InvalidOption('The "--cpu-counts" value "%s" is not an integer' % (s,));
+ if c <= 0: raise base.InvalidOption('The "--cpu-counts" value "%s" is zero or negative' % (s,));
+ self.acCpus.append(c);
+ elif asArgs[iArg] == '--test-vms':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--test-vms" takes colon separated list');
+ self.asTestVMs = asArgs[iArg].split(':');
+ for s in self.asTestVMs:
+ if s not in self.asTestVMsDef:
+ raise base.InvalidOption('The "--test-vms" value "%s" is not valid; valid values are: %s' \
+ % (s, ' '.join(self.asTestVMsDef)));
+ elif asArgs[iArg] == '--skip-vms':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--skip-vms" takes colon separated list');
+ self.asSkipVMs = asArgs[iArg].split(':');
+ for s in self.asSkipVMs:
+ if s not in self.asTestVMsDef:
+ reporter.log('warning: The "--test-vms" value "%s" does not specify any of our test VMs.' % (s));
+ elif asArgs[iArg] == '--usb-ctrls':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-ctrls" takes a colon separated list of USB controllers');
+ self.asUsbCtrls = asArgs[iArg].split(':');
+ for s in self.asUsbCtrls:
+ if s not in self.asUsbCtrlsDef:
+ reporter.log('warning: The "--usb-ctrls" value "%s" is not a valid USB controller.' % (s));
+ elif asArgs[iArg] == '--usb-speed':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-speed" takes a colon separated list of USB speeds');
+ self.asUsbSpeed = asArgs[iArg].split(':');
+ for s in self.asUsbSpeed:
+ if s not in self.asUsbSpeedDef:
+ reporter.log('warning: The "--usb-speed" value "%s" is not a valid USB speed.' % (s));
+ elif asArgs[iArg] == '--usb-tests':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-tests" takes a colon separated list of USB tests');
+ self.asUsbTests = asArgs[iArg].split(':');
+ for s in self.asUsbTests:
+ if s not in self.asUsbTestsDef:
+ reporter.log('warning: The "--usb-tests" value "%s" is not a valid USB test.' % (s));
+ elif asArgs[iArg] == '--usb-reattach-cycles':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-reattach-cycles" takes cycle count');
+ try: self.cUsbReattachCycles = int(asArgs[iArg]);
+ except: raise base.InvalidOption('The "--usb-reattach-cycles" value "%s" is not an integer' \
+ % (asArgs[iArg],));
+ if self.cUsbReattachCycles <= 0:
+ raise base.InvalidOption('The "--usb-reattach-cycles" value "%s" is zero or negative.' \
+ % (self.cUsbReattachCycles,));
+ elif asArgs[iArg] == '--hostname':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--hostname" takes a hostname');
+ self.sHostname = asArgs[iArg];
+ elif asArgs[iArg] == '--default-gadget-host':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--default-gadget-host" takes a hostname');
+ self.sGadgetHostnameDef = asArgs[iArg];
+ elif asArgs[iArg] == '--default-gadget-port':
+ iArg += 1;
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--default-gadget-port" takes port number');
+ try: self.uGadgetPortDef = int(asArgs[iArg]);
+ except: raise base.InvalidOption('The "--default-gadget-port" value "%s" is not an integer' \
+ % (asArgs[iArg],));
+ if self.uGadgetPortDef <= 0:
+ raise base.InvalidOption('The "--default-gadget-port" value "%s" is zero or negative.' \
+ % (self.uGadgetPortDef,));
+ elif asArgs[iArg] == '--usb-capture-path':
+ if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-capture-path" takes a path argument');
+ self.sUsbCapturePath = asArgs[iArg];
+ elif asArgs[iArg] == '--usb-capture':
+ self.fUsbCapture = True;
+ else:
+ return vbox.TestDriver.parseOption(self, asArgs, iArg);
+ return iArg + 1;
+
+ def completeOptions(self):
+ # Remove skipped VMs from the test list.
+ for sVM in self.asSkipVMs:
+ try: self.asTestVMs.remove(sVM);
+ except: pass;
+
+ return vbox.TestDriver.completeOptions(self);
+
+ def getResourceSet(self):
+ # Construct the resource list the first time it's queried.
+ if self.asRsrcs is None:
+ self.asRsrcs = [];
+
+ if 'tst-arch' in self.asTestVMs:
+ self.asRsrcs.append('4.2/usb/tst-arch.vdi');
+
+ return self.asRsrcs;
+
+ def actionConfig(self):
+
+ # Some stupid trickery to guess the location of the iso. ## fixme - testsuite unzip ++
+ sVBoxValidationKit_iso = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../VBoxValidationKit.iso'));
+ if not os.path.isfile(sVBoxValidationKit_iso):
+ sVBoxValidationKit_iso = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../VBoxTestSuite.iso'));
+ if not os.path.isfile(sVBoxValidationKit_iso):
+ sVBoxValidationKit_iso = '/mnt/ramdisk/vbox/svn/trunk/validationkit/VBoxValidationKit.iso';
+ if not os.path.isfile(sVBoxValidationKit_iso):
+ sVBoxValidationKit_iso = '/mnt/ramdisk/vbox/svn/trunk/testsuite/VBoxTestSuite.iso';
+ if not os.path.isfile(sVBoxValidationKit_iso):
+ sCur = os.getcwd();
+ for i in range(0, 10):
+ sVBoxValidationKit_iso = os.path.join(sCur, 'validationkit/VBoxValidationKit.iso');
+ if os.path.isfile(sVBoxValidationKit_iso):
+ break;
+ sVBoxValidationKit_iso = os.path.join(sCur, 'testsuite/VBoxTestSuite.iso');
+ if os.path.isfile(sVBoxValidationKit_iso):
+ break;
+ sCur = os.path.abspath(os.path.join(sCur, '..'));
+ if i is None: pass; # shut up pychecker/pylint.
+ if not os.path.isfile(sVBoxValidationKit_iso):
+ sVBoxValidationKit_iso = '/home/bird/validationkit/VBoxValidationKit.iso';
+ if not os.path.isfile(sVBoxValidationKit_iso):
+ sVBoxValidationKit_iso = '/home/bird/testsuite/VBoxTestSuite.iso';
+
+ # Make sure vboxapi has been imported so we can use the constants.
+ if not self.importVBoxApi():
+ return False;
+
+ #
+ # Configure the VMs we're going to use.
+ #
+
+ # Linux VMs
+ if 'tst-arch' in self.asTestVMs:
+ oVM = self.createTestVM('tst-arch', 1, '4.2/usb/tst-arch.vdi', sKind = 'ArchLinux_64', fIoApic = True, \
+ eNic0AttachType = vboxcon.NetworkAttachmentType_NAT, \
+ sDvdImage = sVBoxValidationKit_iso);
+ if oVM is None:
+ return False;
+
+ return True;
+
+ def actionExecute(self):
+ """
+ Execute the testcase.
+ """
+ fRc = self.testUsb();
+ return fRc;
+
+ def getGadgetParams(self, sHostname, sSpeed):
+ """
+ Returns the gadget hostname and port from the
+ given hostname the test is running on and device speed we want to test.
+ """
+ kdGadgetsConfigured = self.kdGadgetParams.get(sHostname);
+ if kdGadgetsConfigured is not None:
+ return kdGadgetsConfigured.get(sSpeed);
+
+ return (self.sGadgetHostnameDef, self.uGadgetPortDef);
+
+ def getCaptureFilePath(self, sUsbCtrl, sSpeed):
+ """
+ Returns capture filename from the given data.
+ """
+
+ return '%s%s%s-%s.pcap' % (self.sUsbCapturePath, os.sep, sUsbCtrl, sSpeed);
+
+ def attachUsbDeviceToVm(self, oSession, sVendorId, sProductId, iBusId,
+ sCaptureFile = None):
+ """
+ Attaches the given USB device to the VM either via a filter
+ or directly if capturing the USB traffic is enabled.
+
+ Returns True on success, False on failure.
+ """
+ fRc = False;
+ if sCaptureFile is None:
+ fRc = oSession.addUsbDeviceFilter('Compliance device', sVendorId = sVendorId, sProductId = sProductId, \
+ sPort = format(iBusId, 'x'));
+ else:
+ # Search for the correct device in the USB device list waiting for some time
+ # to let it appear.
+ iVendorId = int(sVendorId, 16);
+ iProductId = int(sProductId, 16);
+
+ # Try a few times to give VBoxSVC a chance to detect the new device.
+ for _ in xrange(5):
+ fFound = False;
+ aoUsbDevs = self.oVBoxMgr.getArray(self.oVBox.host, 'USBDevices');
+ for oUsbDev in aoUsbDevs:
+ if oUsbDev.vendorId == iVendorId \
+ and oUsbDev.productId == iProductId \
+ and oUsbDev.port == iBusId:
+ fFound = True;
+ fRc = oSession.attachUsbDevice(oUsbDev.id, sCaptureFile);
+ break;
+
+ if fFound:
+ break;
+
+ # Wait a moment until the next try.
+ self.sleep(1);
+
+ if fRc:
+ # Wait a moment to let the USB device appear
+ self.sleep(9);
+
+ return fRc;
+
+ #
+ # Test execution helpers.
+ #
+ def testUsbCompliance(self, oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile = None):
+ """
+ Test VirtualBoxs USB stack in a VM.
+ """
+ # Get configured USB test devices from hostname we are running on
+ sGadgetHost, uGadgetPort = self.getGadgetParams(self.sHostname, sSpeed);
+
+ oUsbGadget = usbgadget.UsbGadget();
+ reporter.log('Connecting to UTS: ' + sGadgetHost);
+ fRc = oUsbGadget.connectTo(30 * 1000, sGadgetHost, uPort = uGadgetPort, fTryConnect = True);
+ if fRc is True:
+ reporter.log('Connect succeeded');
+ self.oVBox.host.addUSBDeviceSource('USBIP', sGadgetHost, sGadgetHost + (':%s' % oUsbGadget.getUsbIpPort()), [], []);
+
+ fSuperSpeed = False;
+ if sSpeed == 'Super':
+ fSuperSpeed = True;
+
+ # Create test device gadget and a filter to attach the device automatically.
+ fRc = oUsbGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest, fSuperSpeed);
+ if fRc is True:
+ iBusId, _ = oUsbGadget.getGadgetBusAndDevId();
+ fRc = self.attachUsbDeviceToVm(oSession, '0525', 'a4a0', iBusId, sCaptureFile);
+ if fRc is True:
+ tupCmdLine = ('UsbTest', );
+ # Exclude a few tests which hang and cause a timeout, need investigation.
+ lstTestsExclude = self.kdUsbTestsDisabled.get(sSpeed);
+ for iTestExclude in lstTestsExclude:
+ tupCmdLine = tupCmdLine + ('--exclude', str(iTestExclude));
+
+ fRc = self.txsRunTest(oTxsSession, 'UsbTest', 3600 * 1000, \
+ '${CDROM}/${OS/ARCH}/UsbTest${EXESUFF}', tupCmdLine);
+ if not fRc:
+ reporter.testFailure('Running USB test utility failed');
+ else:
+ reporter.testFailure('Failed to attach USB device to VM');
+ oUsbGadget.disconnectFrom();
+ else:
+ reporter.testFailure('Failed to impersonate test device');
+
+ self.oVBox.host.removeUSBDeviceSource(sGadgetHost);
+ else:
+ reporter.log('warning: Failed to connect to USB gadget');
+ fRc = None
+
+ _ = sUsbCtrl;
+ return fRc;
+
+ def testUsbReattach(self, oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile = None): # pylint: disable=unused-argument
+ """
+ Tests that rapid connect/disconnect cycles work.
+ """
+ # Get configured USB test devices from hostname we are running on
+ sGadgetHost, uGadgetPort = self.getGadgetParams(self.sHostname, sSpeed);
+
+ oUsbGadget = usbgadget.UsbGadget();
+ reporter.log('Connecting to UTS: ' + sGadgetHost);
+ fRc = oUsbGadget.connectTo(30 * 1000, sGadgetHost, uPort = uGadgetPort, fTryConnect = True);
+ if fRc is True:
+ self.oVBox.host.addUSBDeviceSource('USBIP', sGadgetHost, sGadgetHost + (':%s' % oUsbGadget.getUsbIpPort()), [], []);
+
+ fSuperSpeed = False;
+ if sSpeed == 'Super':
+ fSuperSpeed = True;
+
+ # Create test device gadget and a filter to attach the device automatically.
+ fRc = oUsbGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest, fSuperSpeed);
+ if fRc is True:
+ iBusId, _ = oUsbGadget.getGadgetBusAndDevId();
+ fRc = self.attachUsbDeviceToVm(oSession, '0525', 'a4a0', iBusId, sCaptureFile);
+ if fRc is True:
+
+ # Wait a moment to let the USB device appear
+ self.sleep(3);
+
+ # Do a rapid disconnect reconnect cycle. Wait a second before disconnecting
+ # again or it will happen so fast that the VM can't attach the new device.
+ # @todo: Get rid of the constant wait and use an event to get notified when
+ # the device was attached.
+ for iCycle in xrange (0, self.cUsbReattachCycles):
+ fRc = oUsbGadget.disconnectUsb();
+ fRc = fRc and oUsbGadget.connectUsb();
+ if not fRc:
+ reporter.testFailure('Reattach cycle %s failed on the gadget device' % (iCycle));
+ break;
+ self.sleep(1);
+
+ else:
+ reporter.testFailure('Failed to create USB device filter');
+
+ oUsbGadget.disconnectFrom();
+ else:
+ reporter.testFailure('Failed to impersonate test device');
+ else:
+ reporter.log('warning: Failed to connect to USB gadget');
+ fRc = None
+
+ return fRc;
+
+ def testUsbOneCfg(self, sVmName, sUsbCtrl, sSpeed, sUsbTest):
+ """
+ Runs the specified VM thru one specified test.
+
+ Returns a success indicator on the general test execution. This is not
+ the actual test result.
+ """
+ oVM = self.getVmByName(sVmName);
+
+ # Reconfigure the VM
+ fRc = True;
+ oSession = self.openSession(oVM);
+ if oSession is not None:
+ fRc = fRc and oSession.enableVirtEx(True);
+ fRc = fRc and oSession.enableNestedPaging(True);
+
+ # Make sure controllers are disabled initially.
+ fRc = fRc and oSession.enableUsbOhci(False);
+ fRc = fRc and oSession.enableUsbEhci(False);
+ fRc = fRc and oSession.enableUsbXhci(False);
+
+ if sUsbCtrl == 'OHCI':
+ fRc = fRc and oSession.enableUsbOhci(True);
+ elif sUsbCtrl == 'EHCI':
+ fRc = fRc and oSession.enableUsbEhci(True);
+ elif sUsbCtrl == 'XHCI':
+ fRc = fRc and oSession.enableUsbXhci(True);
+ fRc = fRc and oSession.saveSettings();
+ fRc = oSession.close() and fRc and True; # pychecker hack.
+ oSession = None;
+ else:
+ fRc = False;
+
+ # Start up.
+ if fRc is True:
+ self.logVmInfo(oVM);
+ oSession, oTxsSession = self.startVmAndConnectToTxsViaTcp(sVmName, fCdWait = False, fNatForwardingForTxs = False);
+ if oSession is not None:
+ self.addTask(oTxsSession);
+
+ # Fudge factor - Allow the guest to finish starting up.
+ self.sleep(5);
+
+ sCaptureFile = None;
+ if self.fUsbCapture:
+ sCaptureFile = self.getCaptureFilePath(sUsbCtrl, sSpeed);
+
+ if sUsbTest == 'Compliance':
+ fRc = self.testUsbCompliance(oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile);
+ elif sUsbTest == 'Reattach':
+ fRc = self.testUsbReattach(oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile);
+
+ # cleanup.
+ self.removeTask(oTxsSession);
+ self.terminateVmBySession(oSession)
+
+ # Add the traffic dump if it exists and the test failed
+ if reporter.testErrorCount() > 0 \
+ and sCaptureFile is not None \
+ and os.path.exists(sCaptureFile):
+ reporter.addLogFile(sCaptureFile, 'misc/other', 'USB traffic dump');
+ else:
+ fRc = False;
+ return fRc;
+
+ def testUsbForOneVM(self, sVmName):
+ """
+ Runs one VM thru the various configurations.
+ """
+ fRc = False;
+ reporter.testStart(sVmName);
+ for sUsbCtrl in self.asUsbCtrls:
+ reporter.testStart(sUsbCtrl)
+ for sUsbSpeed in self.asUsbSpeed:
+ asSupportedSpeeds = self.kdUsbSpeedMappings.get(sUsbCtrl);
+ if sUsbSpeed in asSupportedSpeeds:
+ reporter.testStart(sUsbSpeed)
+ for sUsbTest in self.asUsbTests:
+ reporter.testStart(sUsbTest)
+ fRc = self.testUsbOneCfg(sVmName, sUsbCtrl, sUsbSpeed, sUsbTest);
+ reporter.testDone();
+ reporter.testDone();
+ reporter.testDone();
+ reporter.testDone();
+ return fRc;
+
+ def testUsb(self):
+ """
+ Executes USB test.
+ """
+
+ reporter.log("Running on host: " + self.sHostname);
+
+ # Loop thru the test VMs.
+ for sVM in self.asTestVMs:
+ # run test on the VM.
+ fRc = self.testUsbForOneVM(sVM);
+
+ return fRc;
+
+
+
+if __name__ == '__main__':
+ sys.exit(tdUsbBenchmark().main(sys.argv));
+
diff --git a/src/VBox/ValidationKit/tests/usb/tst-utsgadget.py b/src/VBox/ValidationKit/tests/usb/tst-utsgadget.py
new file mode 100755
index 00000000..03245e00
--- /dev/null
+++ b/src/VBox/ValidationKit/tests/usb/tst-utsgadget.py
@@ -0,0 +1,154 @@
+# -*- coding: utf-8 -*-
+# $Id: tst-utsgadget.py $
+
+"""
+Simple testcase for usbgadget2.py.
+"""
+
+__copyright__ = \
+"""
+Copyright (C) 2016-2023 Oracle and/or its affiliates.
+
+This file is part of VirtualBox base platform packages, as
+available from https://www.virtualbox.org.
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation, in version 3 of the
+License.
+
+This program is distributed in the hope that it will be useful, but
+WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, see <https://www.gnu.org/licenses>.
+
+The contents of this file may alternatively be used under the terms
+of the Common Development and Distribution License Version 1.0
+(CDDL), a copy of it is provided in the "COPYING.CDDL" file included
+in the VirtualBox distribution, in which case the provisions of the
+CDDL are applicable instead of those of the GPL.
+
+You may elect to license modified versions of this file under the
+terms and conditions of either the GPL or the CDDL or both.
+
+SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0
+"""
+__version__ = "$Revision: 155244 $"
+
+# Standard python imports.
+import sys
+
+# Validation Kit imports.
+sys.path.insert(0, '.');
+sys.path.insert(0, '..');
+sys.path.insert(0, '../..');
+from common import utils;
+from testdriver import reporter;
+import usbgadget;
+
+
+# Python 3 hacks:
+if sys.version_info[0] >= 3:
+ long = int; # pylint: disable=redefined-builtin,invalid-name
+
+
+g_cTests = 0;
+g_cFailures = 0
+
+def boolRes(rc, fExpect = True):
+ """Checks a boolean result."""
+ global g_cTests, g_cFailures;
+ g_cTests = g_cTests + 1;
+ if isinstance(rc, bool):
+ if rc == fExpect:
+ return 'PASSED';
+ g_cFailures = g_cFailures + 1;
+ return 'FAILED';
+
+def stringRes(rc, sExpect):
+ """Checks a string result."""
+ global g_cTests, g_cFailures;
+ g_cTests = g_cTests + 1;
+ if utils.isString(rc):
+ if rc == sExpect:
+ return 'PASSED';
+ g_cFailures = g_cFailures + 1;
+ return 'FAILED';
+
+def main(asArgs): # pylint: disable=missing-docstring,too-many-locals,too-many-statements
+ cMsTimeout = long(30*1000);
+ sAddress = 'localhost';
+ uPort = None;
+ fStdTests = True;
+
+ i = 1;
+ while i < len(asArgs):
+ if asArgs[i] == '--hostname':
+ sAddress = asArgs[i + 1];
+ i = i + 2;
+ elif asArgs[i] == '--port':
+ uPort = int(asArgs[i + 1]);
+ i = i + 2;
+ elif asArgs[i] == '--timeout':
+ cMsTimeout = long(asArgs[i + 1]);
+ i = i + 2;
+ elif asArgs[i] == '--help':
+ print('tst-utsgadget.py [--hostname <addr|name>] [--port <num>] [--timeout <cMS>]');
+ return 0;
+ else:
+ print('Unknown argument: %s' % (asArgs[i],));
+ return 2;
+
+ oGadget = usbgadget.UsbGadget();
+ if uPort is None:
+ rc = oGadget.connectTo(cMsTimeout, sAddress);
+ else:
+ rc = oGadget.connectTo(cMsTimeout, sAddress, uPort = uPort);
+ if rc is False:
+ print('connectTo failed');
+ return 1;
+
+ if fStdTests:
+ rc = oGadget.getUsbIpPort() is not None;
+ print('%s: getUsbIpPort() -> %s' % (boolRes(rc), oGadget.getUsbIpPort(),));
+
+ rc = oGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest);
+ print('%s: impersonate()' % (boolRes(rc),));
+
+ rc = oGadget.disconnectUsb();
+ print('%s: disconnectUsb()' % (boolRes(rc),));
+
+ rc = oGadget.connectUsb();
+ print('%s: connectUsb()' % (boolRes(rc),));
+
+ rc = oGadget.clearImpersonation();
+ print('%s: clearImpersonation()' % (boolRes(rc),));
+
+ # Test super speed (and therefore passing configuration items)
+ rc = oGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest, True);
+ print('%s: impersonate(, True)' % (boolRes(rc),));
+
+ rc = oGadget.clearImpersonation();
+ print('%s: clearImpersonation()' % (boolRes(rc),));
+
+ # Done
+ rc = oGadget.disconnectFrom();
+ print('%s: disconnectFrom() -> %s' % (boolRes(rc), rc,));
+
+ if g_cFailures != 0:
+ print('tst-utsgadget.py: %u out of %u test failed' % (g_cFailures, g_cTests,));
+ return 1;
+ print('tst-utsgadget.py: all %u tests passed!' % (g_cTests,));
+ return 0;
+
+
+if __name__ == '__main__':
+ reporter.incVerbosity();
+ reporter.incVerbosity();
+ reporter.incVerbosity();
+ reporter.incVerbosity();
+ sys.exit(main(sys.argv));
+
diff --git a/src/VBox/ValidationKit/tests/usb/usbgadget.py b/src/VBox/ValidationKit/tests/usb/usbgadget.py
new file mode 100755
index 00000000..0a7c4805
--- /dev/null
+++ b/src/VBox/ValidationKit/tests/usb/usbgadget.py
@@ -0,0 +1,1478 @@
+# -*- coding: utf-8 -*-
+# $Id: usbgadget.py $
+# pylint: disable=too-many-lines
+
+"""
+UTS (USB Test Service) client.
+"""
+__copyright__ = \
+"""
+Copyright (C) 2010-2023 Oracle and/or its affiliates.
+
+This file is part of VirtualBox base platform packages, as
+available from https://www.virtualbox.org.
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation, in version 3 of the
+License.
+
+This program is distributed in the hope that it will be useful, but
+WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, see <https://www.gnu.org/licenses>.
+
+The contents of this file may alternatively be used under the terms
+of the Common Development and Distribution License Version 1.0
+(CDDL), a copy of it is provided in the "COPYING.CDDL" file included
+in the VirtualBox distribution, in which case the provisions of the
+CDDL are applicable instead of those of the GPL.
+
+You may elect to license modified versions of this file under the
+terms and conditions of either the GPL or the CDDL or both.
+
+SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0
+"""
+__version__ = "$Revision: 155244 $"
+
+# Standard Python imports.
+import array
+import errno
+import select
+import socket
+import sys;
+import threading
+import time
+import zlib
+
+# Validation Kit imports.
+from common import utils;
+from testdriver import base;
+from testdriver import reporter;
+from testdriver.base import TdTaskBase;
+
+# Python 3 hacks:
+if sys.version_info[0] >= 3:
+ long = int; # pylint: disable=redefined-builtin,invalid-name
+
+
+## @name USB gadget impersonation string constants.
+## @{
+g_ksGadgetImpersonationInvalid = 'Invalid';
+g_ksGadgetImpersonationTest = 'Test';
+g_ksGadgetImpersonationMsd = 'Msd';
+g_ksGadgetImpersonationWebcam = 'Webcam';
+g_ksGadgetImpersonationEther = 'Ether';
+## @}
+
+## @name USB gadget type used in the UTS protocol.
+## @{
+g_kiGadgetTypeTest = 1;
+## @}
+
+## @name USB gadget access methods used in the UTS protocol.
+## @{
+g_kiGadgetAccessUsbIp = 1;
+## @}
+
+## @name USB gadget config types.
+## @{
+g_kiGadgetCfgTypeBool = 1;
+g_kiGadgetCfgTypeString = 2;
+g_kiGadgetCfgTypeUInt8 = 3;
+g_kiGadgetCfgTypeUInt16 = 4;
+g_kiGadgetCfgTypeUInt32 = 5;
+g_kiGadgetCfgTypeUInt64 = 6;
+g_kiGadgetCfgTypeInt8 = 7;
+g_kiGadgetCfgTypeInt16 = 8;
+g_kiGadgetCfgTypeInt32 = 9;
+g_kiGadgetCfgTypeInt64 = 10;
+## @}
+
+#
+# Helpers for decoding data received from the UTS.
+# These are used both the Session and Transport classes.
+#
+
+def getU64(abData, off):
+ """Get a U64 field."""
+ return abData[off] \
+ + abData[off + 1] * 256 \
+ + abData[off + 2] * 65536 \
+ + abData[off + 3] * 16777216 \
+ + abData[off + 4] * 4294967296 \
+ + abData[off + 5] * 1099511627776 \
+ + abData[off + 6] * 281474976710656 \
+ + abData[off + 7] * 72057594037927936;
+
+def getU32(abData, off):
+ """Get a U32 field."""
+ return abData[off] \
+ + abData[off + 1] * 256 \
+ + abData[off + 2] * 65536 \
+ + abData[off + 3] * 16777216;
+
+def getU16(abData, off):
+ """Get a U16 field."""
+ return abData[off] \
+ + abData[off + 1] * 256;
+
+def getU8(abData, off):
+ """Get a U8 field."""
+ return abData[off];
+
+def getSZ(abData, off, sDefault = None):
+ """
+ Get a zero-terminated string field.
+ Returns sDefault if the string is invalid.
+ """
+ cchStr = getSZLen(abData, off);
+ if cchStr >= 0:
+ abStr = abData[off:(off + cchStr)];
+ try:
+ return abStr.tostring().decode('utf_8');
+ except:
+ reporter.errorXcpt('getSZ(,%u)' % (off));
+ return sDefault;
+
+def getSZLen(abData, off):
+ """
+ Get the length of a zero-terminated string field, in bytes.
+ Returns -1 if off is beyond the data packet or not properly terminated.
+ """
+ cbData = len(abData);
+ if off >= cbData:
+ return -1;
+
+ offCur = off;
+ while abData[offCur] != 0:
+ offCur = offCur + 1;
+ if offCur >= cbData:
+ return -1;
+
+ return offCur - off;
+
+def isValidOpcodeEncoding(sOpcode):
+ """
+ Checks if the specified opcode is valid or not.
+ Returns True on success.
+ Returns False if it is invalid, details in the log.
+ """
+ sSet1 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+ sSet2 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_ ";
+ if len(sOpcode) != 8:
+ reporter.error("invalid opcode length: %s" % (len(sOpcode)));
+ return False;
+ for i in range(0, 1):
+ if sSet1.find(sOpcode[i]) < 0:
+ reporter.error("invalid opcode char #%u: %s" % (i, sOpcode));
+ return False;
+ for i in range(2, 7):
+ if sSet2.find(sOpcode[i]) < 0:
+ reporter.error("invalid opcode char #%u: %s" % (i, sOpcode));
+ return False;
+ return True;
+
+#
+# Helper for encoding data sent to the UTS.
+#
+
+def u32ToByteArray(u32):
+ """Encodes the u32 value as a little endian byte (B) array."""
+ return array.array('B', \
+ ( u32 % 256, \
+ (u32 // 256) % 256, \
+ (u32 // 65536) % 256, \
+ (u32 // 16777216) % 256) );
+
+def u16ToByteArray(u16):
+ """Encodes the u16 value as a little endian byte (B) array."""
+ return array.array('B', \
+ ( u16 % 256, \
+ (u16 // 256) % 256) );
+
+def u8ToByteArray(uint8):
+ """Encodes the u8 value as a little endian byte (B) array."""
+ return array.array('B', (uint8 % 256));
+
+def zeroByteArray(cb):
+ """Returns an array with the given size containing 0."""
+ abArray = array.array('B', (0, ));
+ cb = cb - 1;
+ for i in range(cb): # pylint: disable=unused-variable
+ abArray.append(0);
+ return abArray;
+
+def strToByteArry(sStr):
+ """Encodes the string as a little endian byte (B) array including the terminator."""
+ abArray = array.array('B');
+ sUtf8 = sStr.encode('utf_8');
+ for ch in sUtf8:
+ abArray.append(ord(ch));
+ abArray.append(0);
+ return abArray;
+
+def cfgListToByteArray(lst):
+ """Encodes the given config list as a little endian byte (B) array."""
+ abArray = array.array('B');
+ if lst is not None:
+ for t3Item in lst:
+ # Encode they key size
+ abArray.extend(u32ToByteArray(len(t3Item[0]) + 1)); # Include terminator
+ abArray.extend(u32ToByteArray(t3Item[1])) # Config type
+ abArray.extend(u32ToByteArray(len(t3Item[2]) + 1)); # Value size including temrinator.
+ abArray.extend(u32ToByteArray(0)); # Reserved field.
+
+ abArray.extend(strToByteArry(t3Item[0]));
+ abArray.extend(strToByteArry(t3Item[2]));
+
+ return abArray;
+
+class TransportBase(object):
+ """
+ Base class for the transport layer.
+ """
+
+ def __init__(self, sCaller):
+ self.sDbgCreated = '%s: %s' % (utils.getTimePrefix(), sCaller);
+ self.fDummy = 0;
+ self.abReadAheadHdr = array.array('B');
+
+ def toString(self):
+ """
+ Stringify the instance for logging and debugging.
+ """
+ return '<%s: abReadAheadHdr=%s, sDbgCreated=%s>' % (type(self).__name__, self.abReadAheadHdr, self.sDbgCreated);
+
+ def __str__(self):
+ return self.toString();
+
+ def cancelConnect(self):
+ """
+ Cancels any pending connect() call.
+ Returns None;
+ """
+ return None;
+
+ def connect(self, cMsTimeout):
+ """
+ Quietly attempts to connect to the UTS.
+
+ Returns True on success.
+ Returns False on retryable errors (no logging).
+ Returns None on fatal errors with details in the log.
+
+ Override this method, don't call super.
+ """
+ _ = cMsTimeout;
+ return False;
+
+ def disconnect(self, fQuiet = False):
+ """
+ Disconnect from the UTS.
+
+ Returns True.
+
+ Override this method, don't call super.
+ """
+ _ = fQuiet;
+ return True;
+
+ def sendBytes(self, abBuf, cMsTimeout):
+ """
+ Sends the bytes in the buffer abBuf to the UTS.
+
+ Returns True on success.
+ Returns False on failure and error details in the log.
+
+ Override this method, don't call super.
+
+ Remarks: len(abBuf) is always a multiple of 16.
+ """
+ _ = abBuf; _ = cMsTimeout;
+ return False;
+
+ def recvBytes(self, cb, cMsTimeout, fNoDataOk):
+ """
+ Receive cb number of bytes from the UTS.
+
+ Returns the bytes (array('B')) on success.
+ Returns None on failure and error details in the log.
+
+ Override this method, don't call super.
+
+ Remarks: cb is always a multiple of 16.
+ """
+ _ = cb; _ = cMsTimeout; _ = fNoDataOk;
+ return None;
+
+ def isConnectionOk(self):
+ """
+ Checks if the connection is OK.
+
+ Returns True if it is.
+ Returns False if it isn't (caller should call diconnect).
+
+ Override this method, don't call super.
+ """
+ return True;
+
+ def isRecvPending(self, cMsTimeout = 0):
+ """
+ Checks if there is incoming bytes, optionally waiting cMsTimeout
+ milliseconds for something to arrive.
+
+ Returns True if there is, False if there isn't.
+
+ Override this method, don't call super.
+ """
+ _ = cMsTimeout;
+ return False;
+
+ def sendMsgInt(self, sOpcode, cMsTimeout, abPayload = array.array('B')):
+ """
+ Sends a message (opcode + encoded payload).
+
+ Returns True on success.
+ Returns False on failure and error details in the log.
+ """
+ # Fix + check the opcode.
+ if len(sOpcode) < 2:
+ reporter.fatal('sendMsgInt: invalid opcode length: %d (\"%s\")' % (len(sOpcode), sOpcode));
+ return False;
+ sOpcode = sOpcode.ljust(8);
+ if not isValidOpcodeEncoding(sOpcode):
+ reporter.fatal('sendMsgInt: invalid opcode encoding: \"%s\"' % (sOpcode));
+ return False;
+
+ # Start construct the message.
+ cbMsg = 16 + len(abPayload);
+ abMsg = array.array('B');
+ abMsg.extend(u32ToByteArray(cbMsg));
+ abMsg.extend((0, 0, 0, 0)); # uCrc32
+ try:
+ abMsg.extend(array.array('B', \
+ ( ord(sOpcode[0]), \
+ ord(sOpcode[1]), \
+ ord(sOpcode[2]), \
+ ord(sOpcode[3]), \
+ ord(sOpcode[4]), \
+ ord(sOpcode[5]), \
+ ord(sOpcode[6]), \
+ ord(sOpcode[7]) ) ) );
+ if abPayload:
+ abMsg.extend(abPayload);
+ except:
+ reporter.fatalXcpt('sendMsgInt: packing problem...');
+ return False;
+
+ # checksum it, padd it and send it off.
+ uCrc32 = zlib.crc32(abMsg[8:]);
+ abMsg[4:8] = u32ToByteArray(uCrc32);
+
+ while len(abMsg) % 16:
+ abMsg.append(0);
+
+ reporter.log2('sendMsgInt: op=%s len=%d to=%d' % (sOpcode, len(abMsg), cMsTimeout));
+ return self.sendBytes(abMsg, cMsTimeout);
+
+ def recvMsg(self, cMsTimeout, fNoDataOk = False):
+ """
+ Receives a message from the UTS.
+
+ Returns the message three-tuple: length, opcode, payload.
+ Returns (None, None, None) on failure and error details in the log.
+ """
+
+ # Read the header.
+ if self.abReadAheadHdr:
+ assert(len(self.abReadAheadHdr) == 16);
+ abHdr = self.abReadAheadHdr;
+ self.abReadAheadHdr = array.array('B');
+ else:
+ abHdr = self.recvBytes(16, cMsTimeout, fNoDataOk); # (virtual method) # pylint: disable=assignment-from-none
+ if abHdr is None:
+ return (None, None, None);
+ if len(abHdr) != 16:
+ reporter.fatal('recvBytes(16) returns %d bytes!' % (len(abHdr)));
+ return (None, None, None);
+
+ # Unpack and validate the header.
+ cbMsg = getU32(abHdr, 0);
+ uCrc32 = getU32(abHdr, 4);
+ sOpcode = abHdr[8:16].tostring().decode('ascii');
+
+ if cbMsg < 16:
+ reporter.fatal('recvMsg: message length is out of range: %s (min 16 bytes)' % (cbMsg));
+ return (None, None, None);
+ if cbMsg > 1024*1024:
+ reporter.fatal('recvMsg: message length is out of range: %s (max 1MB)' % (cbMsg));
+ return (None, None, None);
+ if not isValidOpcodeEncoding(sOpcode):
+ reporter.fatal('recvMsg: invalid opcode \"%s\"' % (sOpcode));
+ return (None, None, None);
+
+ # Get the payload (if any), dropping the padding.
+ abPayload = array.array('B');
+ if cbMsg > 16:
+ if cbMsg % 16:
+ cbPadding = 16 - (cbMsg % 16);
+ else:
+ cbPadding = 0;
+ abPayload = self.recvBytes(cbMsg - 16 + cbPadding, cMsTimeout, False); # pylint: disable=assignment-from-none
+ if abPayload is None:
+ self.abReadAheadHdr = abHdr;
+ if not fNoDataOk :
+ reporter.log('recvMsg: failed to recv payload bytes!');
+ return (None, None, None);
+
+ while cbPadding > 0:
+ abPayload.pop();
+ cbPadding = cbPadding - 1;
+
+ # Check the CRC-32.
+ if uCrc32 != 0:
+ uActualCrc32 = zlib.crc32(abHdr[8:]);
+ if cbMsg > 16:
+ uActualCrc32 = zlib.crc32(abPayload, uActualCrc32);
+ uActualCrc32 = uActualCrc32 & 0xffffffff;
+ if uCrc32 != uActualCrc32:
+ reporter.fatal('recvMsg: crc error: expected %s, got %s' % (hex(uCrc32), hex(uActualCrc32)));
+ return (None, None, None);
+
+ reporter.log2('recvMsg: op=%s len=%d' % (sOpcode, len(abPayload)));
+ return (cbMsg, sOpcode, abPayload);
+
+ def sendMsg(self, sOpcode, cMsTimeout, aoPayload = ()):
+ """
+ Sends a message (opcode + payload tuple).
+
+ Returns True on success.
+ Returns False on failure and error details in the log.
+ Returns None if you pass the incorrectly typed parameters.
+ """
+ # Encode the payload.
+ abPayload = array.array('B');
+ for o in aoPayload:
+ try:
+ if utils.isString(o):
+ # the primitive approach...
+ sUtf8 = o.encode('utf_8');
+ for ch in sUtf8:
+ abPayload.append(ord(ch))
+ abPayload.append(0);
+ elif isinstance(o, long):
+ if o < 0 or o > 0xffffffff:
+ reporter.fatal('sendMsg: uint32_t payload is out of range: %s' % (hex(o)));
+ return None;
+ abPayload.extend(u32ToByteArray(o));
+ elif isinstance(o, int):
+ if o < 0 or o > 0xffffffff:
+ reporter.fatal('sendMsg: uint32_t payload is out of range: %s' % (hex(o)));
+ return None;
+ abPayload.extend(u32ToByteArray(o));
+ elif isinstance(o, array.array):
+ abPayload.extend(o);
+ else:
+ reporter.fatal('sendMsg: unexpected payload type: %s (%s) (aoPayload=%s)' % (type(o), o, aoPayload));
+ return None;
+ except:
+ reporter.fatalXcpt('sendMsg: screwed up the encoding code...');
+ return None;
+ return self.sendMsgInt(sOpcode, cMsTimeout, abPayload);
+
+
+class Session(TdTaskBase):
+ """
+ A USB Test Service (UTS) client session.
+ """
+
+ def __init__(self, oTransport, cMsTimeout, cMsIdleFudge, fTryConnect = False):
+ """
+ Construct a UTS session.
+
+ This starts by connecting to the UTS and will enter the signalled state
+ when connected or the timeout has been reached.
+ """
+ TdTaskBase.__init__(self, utils.getCallerName());
+ self.oTransport = oTransport;
+ self.sStatus = "";
+ self.cMsTimeout = 0;
+ self.fErr = True; # Whether to report errors as error.
+ self.msStart = 0;
+ self.oThread = None;
+ self.fnTask = self.taskDummy;
+ self.aTaskArgs = None;
+ self.oTaskRc = None;
+ self.t3oReply = (None, None, None);
+ self.fScrewedUpMsgState = False;
+ self.fTryConnect = fTryConnect;
+
+ if not self.startTask(cMsTimeout, False, "connecting", self.taskConnect, (cMsIdleFudge,)):
+ raise base.GenError("startTask failed");
+
+ def __del__(self):
+ """Make sure to cancel the task when deleted."""
+ self.cancelTask();
+
+ def toString(self):
+ return '<%s fnTask=%s, aTaskArgs=%s, sStatus=%s, oTaskRc=%s, cMsTimeout=%s,' \
+ ' msStart=%s, fTryConnect=%s, fErr=%s, fScrewedUpMsgState=%s, t3oReply=%s oTransport=%s, oThread=%s>' \
+ % (TdTaskBase.toString(self), self.fnTask, self.aTaskArgs, self.sStatus, self.oTaskRc, self.cMsTimeout,
+ self.msStart, self.fTryConnect, self.fErr, self.fScrewedUpMsgState, self.t3oReply, self.oTransport, self.oThread);
+
+ def taskDummy(self):
+ """Place holder to catch broken state handling."""
+ raise Exception();
+
+ def startTask(self, cMsTimeout, fIgnoreErrors, sStatus, fnTask, aArgs = ()):
+ """
+ Kicks of a new task.
+
+ cMsTimeout: The task timeout in milliseconds. Values less than
+ 500 ms will be adjusted to 500 ms. This means it is
+ OK to use negative value.
+ sStatus: The task status.
+ fnTask: The method that'll execute the task.
+ aArgs: Arguments to pass to fnTask.
+
+ Returns True on success, False + error in log on failure.
+ """
+ if not self.cancelTask():
+ reporter.maybeErr(not fIgnoreErrors, 'utsclient.Session.startTask: failed to cancel previous task.');
+ return False;
+
+ # Change status and make sure we're the
+ self.lockTask();
+ if self.sStatus != "":
+ self.unlockTask();
+ reporter.maybeErr(not fIgnoreErrors, 'utsclient.Session.startTask: race.');
+ return False;
+ self.sStatus = "setup";
+ self.oTaskRc = None;
+ self.t3oReply = (None, None, None);
+ self.resetTaskLocked();
+ self.unlockTask();
+
+ self.cMsTimeout = max(cMsTimeout, 500);
+ self.fErr = not fIgnoreErrors;
+ self.fnTask = fnTask;
+ self.aTaskArgs = aArgs;
+ self.oThread = threading.Thread(target=self.taskThread, args=(), name=('UTS-%s' % (sStatus)));
+ self.oThread.setDaemon(True); # pylint: disable=deprecated-method
+ self.msStart = base.timestampMilli();
+
+ self.lockTask();
+ self.sStatus = sStatus;
+ self.unlockTask();
+ self.oThread.start();
+
+ return True;
+
+ def cancelTask(self, fSync = True):
+ """
+ Attempts to cancel any pending tasks.
+ Returns success indicator (True/False).
+ """
+ self.lockTask();
+
+ if self.sStatus == "":
+ self.unlockTask();
+ return True;
+ if self.sStatus == "setup":
+ self.unlockTask();
+ return False;
+ if self.sStatus == "cancelled":
+ self.unlockTask();
+ return False;
+
+ reporter.log('utsclient: cancelling "%s"...' % (self.sStatus));
+ if self.sStatus == 'connecting':
+ self.oTransport.cancelConnect();
+
+ self.sStatus = "cancelled";
+ oThread = self.oThread;
+ self.unlockTask();
+
+ if not fSync:
+ return False;
+
+ oThread.join(61.0);
+
+ if sys.version_info < (3, 9, 0):
+ # Removed since Python 3.9.
+ return oThread.isAlive(); # pylint: disable=no-member
+ return oThread.is_alive();
+
+ def taskThread(self):
+ """
+ The task thread function.
+ This does some housekeeping activities around the real task method call.
+ """
+ if not self.isCancelled():
+ try:
+ fnTask = self.fnTask;
+ oTaskRc = fnTask(*self.aTaskArgs);
+ except:
+ reporter.fatalXcpt('taskThread', 15);
+ oTaskRc = None;
+ else:
+ reporter.log('taskThread: cancelled already');
+
+ self.lockTask();
+
+ reporter.log('taskThread: signalling task with status "%s", oTaskRc=%s' % (self.sStatus, oTaskRc));
+ self.oTaskRc = oTaskRc;
+ self.oThread = None;
+ self.sStatus = '';
+ self.signalTaskLocked();
+
+ self.unlockTask();
+ return None;
+
+ def isCancelled(self):
+ """Internal method for checking if the task has been cancelled."""
+ self.lockTask();
+ sStatus = self.sStatus;
+ self.unlockTask();
+ if sStatus == "cancelled":
+ return True;
+ return False;
+
+ def hasTimedOut(self):
+ """Internal method for checking if the task has timed out or not."""
+ cMsLeft = self.getMsLeft();
+ if cMsLeft <= 0:
+ return True;
+ return False;
+
+ def getMsLeft(self, cMsMin = 0, cMsMax = -1):
+ """Gets the time left until the timeout."""
+ cMsElapsed = base.timestampMilli() - self.msStart;
+ if cMsElapsed < 0:
+ return cMsMin;
+ cMsLeft = self.cMsTimeout - cMsElapsed;
+ if cMsLeft <= cMsMin:
+ return cMsMin;
+ if cMsLeft > cMsMax > 0:
+ return cMsMax
+ return cMsLeft;
+
+ def recvReply(self, cMsTimeout = None, fNoDataOk = False):
+ """
+ Wrapper for TransportBase.recvMsg that stashes the response away
+ so the client can inspect it later on.
+ """
+ if cMsTimeout is None:
+ cMsTimeout = self.getMsLeft(500);
+ cbMsg, sOpcode, abPayload = self.oTransport.recvMsg(cMsTimeout, fNoDataOk);
+ self.lockTask();
+ self.t3oReply = (cbMsg, sOpcode, abPayload);
+ self.unlockTask();
+ return (cbMsg, sOpcode, abPayload);
+
+ def recvAck(self, fNoDataOk = False):
+ """
+ Receives an ACK or error response from the UTS.
+
+ Returns True on success.
+ Returns False on timeout or transport error.
+ Returns (sOpcode, sDetails) tuple on failure. The opcode is stripped
+ and there are always details of some sort or another.
+ """
+ cbMsg, sOpcode, abPayload = self.recvReply(None, fNoDataOk);
+ if cbMsg is None:
+ return False;
+ sOpcode = sOpcode.strip()
+ if sOpcode == "ACK":
+ return True;
+ return (sOpcode, getSZ(abPayload, 16, sOpcode));
+
+ def recvAckLogged(self, sCommand, fNoDataOk = False):
+ """
+ Wrapper for recvAck and logging.
+ Returns True on success (ACK).
+ Returns False on time, transport error and errors signalled by UTS.
+ """
+ rc = self.recvAck(fNoDataOk);
+ if rc is not True and not fNoDataOk:
+ if rc is False:
+ reporter.maybeErr(self.fErr, 'recvAckLogged: %s transport error' % (sCommand));
+ else:
+ reporter.maybeErr(self.fErr, 'recvAckLogged: %s response was %s: %s' % (sCommand, rc[0], rc[1]));
+ rc = False;
+ return rc;
+
+ def recvTrueFalse(self, sCommand):
+ """
+ Receives a TRUE/FALSE response from the UTS.
+ Returns True on TRUE, False on FALSE and None on error/other (logged).
+ """
+ cbMsg, sOpcode, abPayload = self.recvReply();
+ if cbMsg is None:
+ reporter.maybeErr(self.fErr, 'recvAckLogged: %s transport error' % (sCommand));
+ return None;
+
+ sOpcode = sOpcode.strip()
+ if sOpcode == "TRUE":
+ return True;
+ if sOpcode == "FALSE":
+ return False;
+ reporter.maybeErr(self.fErr, 'recvAckLogged: %s response was %s: %s' % \
+ (sCommand, sOpcode, getSZ(abPayload, 16, sOpcode)));
+ return None;
+
+ def sendMsg(self, sOpcode, aoPayload = (), cMsTimeout = None):
+ """
+ Wrapper for TransportBase.sendMsg that inserts the correct timeout.
+ """
+ if cMsTimeout is None:
+ cMsTimeout = self.getMsLeft(500);
+ return self.oTransport.sendMsg(sOpcode, cMsTimeout, aoPayload);
+
+ def asyncToSync(self, fnAsync, *aArgs):
+ """
+ Wraps an asynchronous task into a synchronous operation.
+
+ Returns False on failure, task return status on success.
+ """
+ rc = fnAsync(*aArgs);
+ if rc is False:
+ reporter.log2('asyncToSync(%s): returns False (#1)' % (fnAsync));
+ return rc;
+
+ rc = self.waitForTask(self.cMsTimeout + 5000);
+ if rc is False:
+ reporter.maybeErrXcpt(self.fErr, 'asyncToSync: waitForTask failed...');
+ self.cancelTask();
+ #reporter.log2('asyncToSync(%s): returns False (#2)' % (fnAsync, rc));
+ return False;
+
+ rc = self.getResult();
+ #reporter.log2('asyncToSync(%s): returns %s' % (fnAsync, rc));
+ return rc;
+
+ #
+ # Connection tasks.
+ #
+
+ def taskConnect(self, cMsIdleFudge):
+ """Tries to connect to the UTS"""
+ while not self.isCancelled():
+ reporter.log2('taskConnect: connecting ...');
+ rc = self.oTransport.connect(self.getMsLeft(500));
+ if rc is True:
+ reporter.log('taskConnect: succeeded');
+ return self.taskGreet(cMsIdleFudge);
+ if rc is None:
+ reporter.log2('taskConnect: unable to connect');
+ return None;
+ if self.hasTimedOut():
+ reporter.log2('taskConnect: timed out');
+ if not self.fTryConnect:
+ reporter.maybeErr(self.fErr, 'taskConnect: timed out');
+ return False;
+ time.sleep(self.getMsLeft(1, 1000) / 1000.0);
+ if not self.fTryConnect:
+ reporter.maybeErr(self.fErr, 'taskConnect: cancelled');
+ return False;
+
+ def taskGreet(self, cMsIdleFudge):
+ """Greets the UTS"""
+ sHostname = socket.gethostname().lower();
+ cbFill = 68 - len(sHostname) - 1;
+ rc = self.sendMsg("HOWDY", ((1 << 16) | 0, 0x1, len(sHostname), sHostname, zeroByteArray(cbFill)));
+ if rc is True:
+ rc = self.recvAckLogged("HOWDY", self.fTryConnect);
+ if rc is True:
+ while cMsIdleFudge > 0:
+ cMsIdleFudge -= 1000;
+ time.sleep(1);
+ else:
+ self.oTransport.disconnect(self.fTryConnect);
+ return rc;
+
+ def taskBye(self):
+ """Says goodbye to the UTS"""
+ rc = self.sendMsg("BYE");
+ if rc is True:
+ rc = self.recvAckLogged("BYE");
+ self.oTransport.disconnect();
+ return rc;
+
+ #
+ # Gadget tasks.
+ #
+
+ def taskGadgetCreate(self, iGadgetType, iGadgetAccess, lstCfg = None):
+ """Creates a new gadget on UTS"""
+ cCfgItems = 0;
+ if lstCfg is not None:
+ cCfgItems = len(lstCfg);
+ fRc = self.sendMsg("GDGTCRT", (iGadgetType, iGadgetAccess, cCfgItems, 0, cfgListToByteArray(lstCfg)));
+ if fRc is True:
+ fRc = self.recvAckLogged("GDGTCRT");
+ return fRc;
+
+ def taskGadgetDestroy(self, iGadgetId):
+ """Destroys the given gadget handle on UTS"""
+ fRc = self.sendMsg("GDGTDTOR", (iGadgetId, zeroByteArray(12)));
+ if fRc is True:
+ fRc = self.recvAckLogged("GDGTDTOR");
+ return fRc;
+
+ def taskGadgetConnect(self, iGadgetId):
+ """Connects the given gadget handle on UTS"""
+ fRc = self.sendMsg("GDGTCNCT", (iGadgetId, zeroByteArray(12)));
+ if fRc is True:
+ fRc = self.recvAckLogged("GDGTCNCT");
+ return fRc;
+
+ def taskGadgetDisconnect(self, iGadgetId):
+ """Disconnects the given gadget handle from UTS"""
+ fRc = self.sendMsg("GDGTDCNT", (iGadgetId, zeroByteArray(12)));
+ if fRc is True:
+ fRc = self.recvAckLogged("GDGTDCNT");
+ return fRc;
+
+ #
+ # Public methods - generic task queries
+ #
+
+ def isSuccess(self):
+ """Returns True if the task completed successfully, otherwise False."""
+ self.lockTask();
+ sStatus = self.sStatus;
+ oTaskRc = self.oTaskRc;
+ self.unlockTask();
+ if sStatus != "":
+ return False;
+ if oTaskRc is False or oTaskRc is None:
+ return False;
+ return True;
+
+ def getResult(self):
+ """
+ Returns the result of a completed task.
+ Returns None if not completed yet or no previous task.
+ """
+ self.lockTask();
+ sStatus = self.sStatus;
+ oTaskRc = self.oTaskRc;
+ self.unlockTask();
+ if sStatus != "":
+ return None;
+ return oTaskRc;
+
+ def getLastReply(self):
+ """
+ Returns the last reply three-tuple: cbMsg, sOpcode, abPayload.
+ Returns a None, None, None three-tuple if there was no last reply.
+ """
+ self.lockTask();
+ t3oReply = self.t3oReply;
+ self.unlockTask();
+ return t3oReply;
+
+ #
+ # Public methods - connection.
+ #
+
+ def asyncDisconnect(self, cMsTimeout = 30000, fIgnoreErrors = False):
+ """
+ Initiates a disconnect task.
+
+ Returns True on success, False on failure (logged).
+
+ The task returns True on success and False on failure.
+ """
+ return self.startTask(cMsTimeout, fIgnoreErrors, "bye", self.taskBye);
+
+ def syncDisconnect(self, cMsTimeout = 30000, fIgnoreErrors = False):
+ """Synchronous version."""
+ return self.asyncToSync(self.asyncDisconnect, cMsTimeout, fIgnoreErrors);
+
+ #
+ # Public methods - gadget API
+ #
+
+ def asyncGadgetCreate(self, iGadgetType, iGadgetAccess, lstCfg = None, cMsTimeout = 30000, fIgnoreErrors = False):
+ """
+ Initiates a gadget create task.
+
+ Returns True on success, False on failure (logged).
+
+ The task returns True on success and False on failure.
+ """
+ return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetCreate", self.taskGadgetCreate, \
+ (iGadgetType, iGadgetAccess, lstCfg));
+
+ def syncGadgetCreate(self, iGadgetType, iGadgetAccess, lstCfg = None, cMsTimeout = 30000, fIgnoreErrors = False):
+ """Synchronous version."""
+ return self.asyncToSync(self.asyncGadgetCreate, iGadgetType, iGadgetAccess, lstCfg, cMsTimeout, fIgnoreErrors);
+
+ def asyncGadgetDestroy(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False):
+ """
+ Initiates a gadget destroy task.
+
+ Returns True on success, False on failure (logged).
+
+ The task returns True on success and False on failure.
+ """
+ return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetDestroy", self.taskGadgetDestroy, \
+ (iGadgetId, ));
+
+ def syncGadgetDestroy(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False):
+ """Synchronous version."""
+ return self.asyncToSync(self.asyncGadgetDestroy, iGadgetId, cMsTimeout, fIgnoreErrors);
+
+ def asyncGadgetConnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False):
+ """
+ Initiates a gadget connect task.
+
+ Returns True on success, False on failure (logged).
+
+ The task returns True on success and False on failure.
+ """
+ return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetConnect", self.taskGadgetConnect, \
+ (iGadgetId, ));
+
+ def syncGadgetConnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False):
+ """Synchronous version."""
+ return self.asyncToSync(self.asyncGadgetConnect, iGadgetId, cMsTimeout, fIgnoreErrors);
+
+ def asyncGadgetDisconnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False):
+ """
+ Initiates a gadget disconnect task.
+
+ Returns True on success, False on failure (logged).
+
+ The task returns True on success and False on failure.
+ """
+ return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetDisconnect", self.taskGadgetDisconnect, \
+ (iGadgetId, ));
+
+ def syncGadgetDisconnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False):
+ """Synchronous version."""
+ return self.asyncToSync(self.asyncGadgetDisconnect, iGadgetId, cMsTimeout, fIgnoreErrors);
+
+
+class TransportTcp(TransportBase):
+ """
+ TCP transport layer for the UTS client session class.
+ """
+
+ def __init__(self, sHostname, uPort):
+ """
+ Save the parameters. The session will call us back to make the
+ connection later on its worker thread.
+ """
+ TransportBase.__init__(self, utils.getCallerName());
+ self.sHostname = sHostname;
+ self.uPort = uPort if uPort is not None else 6042;
+ self.oSocket = None;
+ self.oWakeupW = None;
+ self.oWakeupR = None;
+ self.fConnectCanceled = False;
+ self.fIsConnecting = False;
+ self.oCv = threading.Condition();
+ self.abReadAhead = array.array('B');
+
+ def toString(self):
+ return '<%s sHostname=%s, uPort=%s, oSocket=%s,'\
+ ' fConnectCanceled=%s, fIsConnecting=%s, oCv=%s, abReadAhead=%s>' \
+ % (TransportBase.toString(self), self.sHostname, self.uPort, self.oSocket,
+ self.fConnectCanceled, self.fIsConnecting, self.oCv, self.abReadAhead);
+
+ def __isInProgressXcpt(self, oXcpt):
+ """ In progress exception? """
+ try:
+ if isinstance(oXcpt, socket.error):
+ try:
+ if oXcpt[0] == errno.EINPROGRESS:
+ return True;
+ except: pass;
+ try:
+ if oXcpt[0] == errno.EWOULDBLOCK:
+ return True;
+ if utils.getHostOs() == 'win' and oXcpt[0] == errno.WSAEWOULDBLOCK: # pylint: disable=no-member
+ return True;
+ except: pass;
+ except:
+ pass;
+ return False;
+
+ def __isWouldBlockXcpt(self, oXcpt):
+ """ Would block exception? """
+ try:
+ if isinstance(oXcpt, socket.error):
+ try:
+ if oXcpt[0] == errno.EWOULDBLOCK:
+ return True;
+ except: pass;
+ try:
+ if oXcpt[0] == errno.EAGAIN:
+ return True;
+ except: pass;
+ except:
+ pass;
+ return False;
+
+ def __isConnectionReset(self, oXcpt):
+ """ Connection reset by Peer or others. """
+ try:
+ if isinstance(oXcpt, socket.error):
+ try:
+ if oXcpt[0] == errno.ECONNRESET:
+ return True;
+ except: pass;
+ try:
+ if oXcpt[0] == errno.ENETRESET:
+ return True;
+ except: pass;
+ except:
+ pass;
+ return False;
+
+ def _closeWakeupSockets(self):
+ """ Closes the wakup sockets. Caller should own the CV. """
+ oWakeupR = self.oWakeupR;
+ self.oWakeupR = None;
+ if oWakeupR is not None:
+ oWakeupR.close();
+
+ oWakeupW = self.oWakeupW;
+ self.oWakeupW = None;
+ if oWakeupW is not None:
+ oWakeupW.close();
+
+ return None;
+
+ def cancelConnect(self):
+ # This is bad stuff.
+ self.oCv.acquire();
+ reporter.log2('TransportTcp::cancelConnect: fIsConnecting=%s oSocket=%s' % (self.fIsConnecting, self.oSocket));
+ self.fConnectCanceled = True;
+ if self.fIsConnecting:
+ oSocket = self.oSocket;
+ self.oSocket = None;
+ if oSocket is not None:
+ reporter.log2('TransportTcp::cancelConnect: closing the socket');
+ oSocket.close();
+
+ oWakeupW = self.oWakeupW;
+ self.oWakeupW = None;
+ if oWakeupW is not None:
+ reporter.log2('TransportTcp::cancelConnect: wakeup call');
+ try: oWakeupW.send('cancelled!\n');
+ except: reporter.logXcpt();
+ try: oWakeupW.shutdown(socket.SHUT_WR);
+ except: reporter.logXcpt();
+ oWakeupW.close();
+ self.oCv.release();
+
+ def _connectAsClient(self, oSocket, oWakeupR, cMsTimeout):
+ """ Connects to the UTS server as client. """
+
+ # Connect w/ timeouts.
+ rc = None;
+ try:
+ oSocket.connect((self.sHostname, self.uPort));
+ rc = True;
+ except socket.error as oXcpt:
+ iRc = oXcpt.errno;
+ if self.__isInProgressXcpt(oXcpt):
+ # Do the actual waiting.
+ reporter.log2('TransportTcp::connect: operation in progress (%s)...' % (oXcpt,));
+ try:
+ ttRc = select.select([oWakeupR], [oSocket], [oSocket, oWakeupR], cMsTimeout / 1000.0);
+ if len(ttRc[1]) + len(ttRc[2]) == 0:
+ raise socket.error(errno.ETIMEDOUT, 'select timed out');
+ iRc = oSocket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR);
+ rc = iRc == 0;
+ except socket.error as oXcpt2:
+ iRc = oXcpt2.errno;
+ except:
+ iRc = -42;
+ reporter.fatalXcpt('socket.select() on connect failed');
+
+ if rc is True:
+ pass;
+ elif iRc in (errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.EINTR, errno.ENETDOWN, errno.ENETUNREACH, errno.ETIMEDOUT):
+ rc = False; # try again.
+ else:
+ if iRc != errno.EBADF or not self.fConnectCanceled:
+ reporter.fatalXcpt('socket.connect((%s,%s)) failed; iRc=%s' % (self.sHostname, self.uPort, iRc));
+ reporter.log2('TransportTcp::connect: rc=%s iRc=%s' % (rc, iRc));
+ except:
+ reporter.fatalXcpt('socket.connect((%s,%s)) failed' % (self.sHostname, self.uPort));
+ return rc;
+
+
+ def connect(self, cMsTimeout):
+ # Create a non-blocking socket.
+ reporter.log2('TransportTcp::connect: cMsTimeout=%s sHostname=%s uPort=%s' % (cMsTimeout, self.sHostname, self.uPort));
+ try:
+ oSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0);
+ except:
+ reporter.fatalXcpt('socket.socket() failed');
+ return None;
+ try:
+ oSocket.setblocking(0);
+ except:
+ oSocket.close();
+ reporter.fatalXcpt('socket.socket() failed');
+ return None;
+
+ # Create wakeup socket pair for unix (select doesn't wake up on socket close on Linux).
+ oWakeupR = None;
+ oWakeupW = None;
+ if hasattr(socket, 'socketpair'):
+ try: (oWakeupR, oWakeupW) = socket.socketpair(); # pylint: disable=no-member
+ except: reporter.logXcpt('socket.socketpair() failed');
+
+ # Update the state.
+ self.oCv.acquire();
+ rc = None;
+ if not self.fConnectCanceled:
+ self.oSocket = oSocket;
+ self.oWakeupW = oWakeupW;
+ self.oWakeupR = oWakeupR;
+ self.fIsConnecting = True;
+ self.oCv.release();
+
+ # Try connect.
+ if oWakeupR is None:
+ oWakeupR = oSocket; # Avoid select failure.
+ rc = self._connectAsClient(oSocket, oWakeupR, cMsTimeout);
+ oSocket = None;
+
+ # Update the state and cleanup on failure/cancel.
+ self.oCv.acquire();
+ if rc is True and self.fConnectCanceled:
+ rc = False;
+ self.fIsConnecting = False;
+
+ if rc is not True:
+ if self.oSocket is not None:
+ self.oSocket.close();
+ self.oSocket = None;
+ self._closeWakeupSockets();
+ self.oCv.release();
+
+ reporter.log2('TransportTcp::connect: returning %s' % (rc,));
+ return rc;
+
+ def disconnect(self, fQuiet = False):
+ if self.oSocket is not None:
+ self.abReadAhead = array.array('B');
+
+ # Try a shutting down the socket gracefully (draining it).
+ try:
+ self.oSocket.shutdown(socket.SHUT_WR);
+ except:
+ if not fQuiet:
+ reporter.error('shutdown(SHUT_WR)');
+ try:
+ self.oSocket.setblocking(0); # just in case it's not set.
+ sData = "1";
+ while sData:
+ sData = self.oSocket.recv(16384);
+ except:
+ pass;
+
+ # Close it.
+ self.oCv.acquire();
+ try: self.oSocket.setblocking(1);
+ except: pass;
+ self.oSocket.close();
+ self.oSocket = None;
+ else:
+ self.oCv.acquire();
+ self._closeWakeupSockets();
+ self.oCv.release();
+
+ def sendBytes(self, abBuf, cMsTimeout):
+ if self.oSocket is None:
+ reporter.error('TransportTcp.sendBytes: No connection.');
+ return False;
+
+ # Try send it all.
+ try:
+ cbSent = self.oSocket.send(abBuf);
+ if cbSent == len(abBuf):
+ return True;
+ except Exception as oXcpt:
+ if not self.__isWouldBlockXcpt(oXcpt):
+ reporter.errorXcpt('TranportTcp.sendBytes: %s bytes' % (len(abBuf)));
+ return False;
+ cbSent = 0;
+
+ # Do a timed send.
+ msStart = base.timestampMilli();
+ while True:
+ cMsElapsed = base.timestampMilli() - msStart;
+ if cMsElapsed > cMsTimeout:
+ reporter.error('TranportTcp.sendBytes: %s bytes timed out (1)' % (len(abBuf)));
+ break;
+
+ # wait.
+ try:
+ ttRc = select.select([], [self.oSocket], [self.oSocket], (cMsTimeout - cMsElapsed) / 1000.0);
+ if ttRc[2] and not ttRc[1]:
+ reporter.error('TranportTcp.sendBytes: select returned with exception');
+ break;
+ if not ttRc[1]:
+ reporter.error('TranportTcp.sendBytes: %s bytes timed out (2)' % (len(abBuf)));
+ break;
+ except:
+ reporter.errorXcpt('TranportTcp.sendBytes: select failed');
+ break;
+
+ # Try send more.
+ try:
+ cbSent += self.oSocket.send(abBuf[cbSent:]);
+ if cbSent == len(abBuf):
+ return True;
+ except Exception as oXcpt:
+ if not self.__isWouldBlockXcpt(oXcpt):
+ reporter.errorXcpt('TranportTcp.sendBytes: %s bytes' % (len(abBuf)));
+ break;
+
+ return False;
+
+ def __returnReadAheadBytes(self, cb):
+ """ Internal worker for recvBytes. """
+ assert(len(self.abReadAhead) >= cb);
+ abRet = self.abReadAhead[:cb];
+ self.abReadAhead = self.abReadAhead[cb:];
+ return abRet;
+
+ def recvBytes(self, cb, cMsTimeout, fNoDataOk):
+ if self.oSocket is None:
+ reporter.error('TransportTcp.recvBytes(%s,%s): No connection.' % (cb, cMsTimeout));
+ return None;
+
+ # Try read in some more data without bothering with timeout handling first.
+ if len(self.abReadAhead) < cb:
+ try:
+ abBuf = self.oSocket.recv(cb - len(self.abReadAhead));
+ if abBuf:
+ self.abReadAhead.extend(array.array('B', abBuf));
+ except Exception as oXcpt:
+ if not self.__isWouldBlockXcpt(oXcpt):
+ reporter.errorXcpt('TranportTcp.recvBytes: 0/%s bytes' % (cb,));
+ return None;
+
+ if len(self.abReadAhead) >= cb:
+ return self.__returnReadAheadBytes(cb);
+
+ # Timeout loop.
+ msStart = base.timestampMilli();
+ while True:
+ cMsElapsed = base.timestampMilli() - msStart;
+ if cMsElapsed > cMsTimeout:
+ if not fNoDataOk or self.abReadAhead:
+ reporter.error('TranportTcp.recvBytes: %s/%s bytes timed out (1)' % (len(self.abReadAhead), cb));
+ break;
+
+ # Wait.
+ try:
+ ttRc = select.select([self.oSocket], [], [self.oSocket], (cMsTimeout - cMsElapsed) / 1000.0);
+ if ttRc[2] and not ttRc[0]:
+ reporter.error('TranportTcp.recvBytes: select returned with exception');
+ break;
+ if not ttRc[0]:
+ if not fNoDataOk or self.abReadAhead:
+ reporter.error('TranportTcp.recvBytes: %s/%s bytes timed out (2) fNoDataOk=%s'
+ % (len(self.abReadAhead), cb, fNoDataOk));
+ break;
+ except:
+ reporter.errorXcpt('TranportTcp.recvBytes: select failed');
+ break;
+
+ # Try read more.
+ try:
+ abBuf = self.oSocket.recv(cb - len(self.abReadAhead));
+ if not abBuf:
+ reporter.error('TranportTcp.recvBytes: %s/%s bytes (%s) - connection has been shut down'
+ % (len(self.abReadAhead), cb, fNoDataOk));
+ self.disconnect();
+ return None;
+
+ self.abReadAhead.extend(array.array('B', abBuf));
+
+ except Exception as oXcpt:
+ reporter.log('recv => exception %s' % (oXcpt,));
+ if not self.__isWouldBlockXcpt(oXcpt):
+ if not fNoDataOk or not self.__isConnectionReset(oXcpt) or self.abReadAhead:
+ reporter.errorXcpt('TranportTcp.recvBytes: %s/%s bytes (%s)' % (len(self.abReadAhead), cb, fNoDataOk));
+ break;
+
+ # Done?
+ if len(self.abReadAhead) >= cb:
+ return self.__returnReadAheadBytes(cb);
+
+ #reporter.log('recv => None len(self.abReadAhead) -> %d' % (len(self.abReadAhead), ));
+ return None;
+
+ def isConnectionOk(self):
+ if self.oSocket is None:
+ return False;
+ try:
+ ttRc = select.select([], [], [self.oSocket], 0.0);
+ if ttRc[2]:
+ return False;
+
+ self.oSocket.send(array.array('B')); # send zero bytes.
+ except:
+ return False;
+ return True;
+
+ def isRecvPending(self, cMsTimeout = 0):
+ try:
+ ttRc = select.select([self.oSocket], [], [], cMsTimeout / 1000.0);
+ if not ttRc[0]:
+ return False;
+ except:
+ pass;
+ return True;
+
+
+class UsbGadget(object):
+ """
+ USB Gadget control class using the USBT Test Service to talk to the external
+ board behaving like a USB device.
+ """
+
+ def __init__(self):
+ self.oUtsSession = None;
+ self.sImpersonation = g_ksGadgetImpersonationInvalid;
+ self.idGadget = None;
+ self.iBusId = None;
+ self.iDevId = None;
+ self.iUsbIpPort = None;
+
+ def clearImpersonation(self):
+ """
+ Removes the current impersonation of the gadget.
+ """
+ fRc = True;
+
+ if self.idGadget is not None:
+ fRc = self.oUtsSession.syncGadgetDestroy(self.idGadget);
+ self.idGadget = None;
+ self.iBusId = None;
+ self.iDevId = None;
+
+ return fRc;
+
+ def disconnectUsb(self):
+ """
+ Disconnects the USB gadget from the host. (USB connection not network
+ connection used for control)
+ """
+ return self.oUtsSession.syncGadgetDisconnect(self.idGadget);
+
+ def connectUsb(self):
+ """
+ Connect the USB gadget to the host.
+ """
+ return self.oUtsSession.syncGadgetConnect(self.idGadget);
+
+ def impersonate(self, sImpersonation, fSuperSpeed = False):
+ """
+ Impersonate a given device.
+ """
+
+ # Clear any previous impersonation
+ self.clearImpersonation();
+ self.sImpersonation = sImpersonation;
+
+ fRc = False;
+ if sImpersonation == g_ksGadgetImpersonationTest:
+ lstCfg = [];
+ if fSuperSpeed is True:
+ lstCfg.append( ('Gadget/SuperSpeed', g_kiGadgetCfgTypeBool, 'true') );
+ fDone = self.oUtsSession.syncGadgetCreate(g_kiGadgetTypeTest, g_kiGadgetAccessUsbIp, lstCfg);
+ if fDone is True and self.oUtsSession.isSuccess():
+ # Get the gadget ID.
+ _, _, abPayload = self.oUtsSession.getLastReply();
+
+ fRc = True;
+ self.idGadget = getU32(abPayload, 16);
+ self.iBusId = getU32(abPayload, 20);
+ self.iDevId = getU32(abPayload, 24);
+ else:
+ reporter.log('Invalid or unsupported impersonation');
+
+ return fRc;
+
+ def getUsbIpPort(self):
+ """
+ Returns the port the USB/IP server is listening on if requested,
+ None if USB/IP is not supported.
+ """
+ return self.iUsbIpPort;
+
+ def getGadgetBusAndDevId(self):
+ """
+ Returns the bus ad device ID of the gadget as a tuple.
+ """
+ return (self.iBusId, self.iDevId);
+
+ def connectTo(self, cMsTimeout, sHostname, uPort = None, fUsbIpSupport = True, cMsIdleFudge = 0, fTryConnect = False):
+ """
+ Connects to the specified target device.
+ Returns True on Success.
+ Returns False otherwise.
+ """
+ fRc = True;
+
+ # @todo
+ if fUsbIpSupport is False:
+ return False;
+
+ reporter.log2('openTcpSession(%s, %s, %s, %s)' % \
+ (cMsTimeout, sHostname, uPort, cMsIdleFudge));
+ try:
+ oTransport = TransportTcp(sHostname, uPort);
+ self.oUtsSession = Session(oTransport, cMsTimeout, cMsIdleFudge, fTryConnect);
+
+ if self.oUtsSession is not None:
+ fDone = self.oUtsSession.waitForTask(30*1000);
+ reporter.log('connect: waitForTask -> %s, result %s' % (fDone, self.oUtsSession.getResult()));
+ if fDone is True and self.oUtsSession.isSuccess():
+ # Parse the reply.
+ _, _, abPayload = self.oUtsSession.getLastReply();
+
+ if getU32(abPayload, 20) is g_kiGadgetAccessUsbIp:
+ fRc = True;
+ self.iUsbIpPort = getU32(abPayload, 24);
+ else:
+ reporter.log('Gadget doesn\'t support access over USB/IP despite being requested');
+ fRc = False;
+ else:
+ fRc = False;
+ else:
+ fRc = False;
+ except:
+ reporter.errorXcpt(None, 15);
+ return False;
+
+ return fRc;
+
+ def disconnectFrom(self):
+ """
+ Disconnects from the target device.
+ """
+ fRc = True;
+
+ self.clearImpersonation();
+ if self.oUtsSession is not None:
+ fRc = self.oUtsSession.syncDisconnect();
+
+ return fRc;