diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-11 08:17:27 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-11 08:17:27 +0000 |
commit | f215e02bf85f68d3a6106c2a1f4f7f063f819064 (patch) | |
tree | 6bb5b92c046312c4e95ac2620b10ddf482d3fa8b /src/VBox/ValidationKit/tests/usb | |
parent | Initial commit. (diff) | |
download | virtualbox-f215e02bf85f68d3a6106c2a1f4f7f063f819064.tar.xz virtualbox-f215e02bf85f68d3a6106c2a1f4f7f063f819064.zip |
Adding upstream version 7.0.14-dfsg.upstream/7.0.14-dfsg
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/VBox/ValidationKit/tests/usb')
-rw-r--r-- | src/VBox/ValidationKit/tests/usb/Makefile.kmk | 53 | ||||
-rwxr-xr-x | src/VBox/ValidationKit/tests/usb/tdUsb1.py | 590 | ||||
-rwxr-xr-x | src/VBox/ValidationKit/tests/usb/tst-utsgadget.py | 154 | ||||
-rwxr-xr-x | src/VBox/ValidationKit/tests/usb/usbgadget.py | 1478 |
4 files changed, 2275 insertions, 0 deletions
diff --git a/src/VBox/ValidationKit/tests/usb/Makefile.kmk b/src/VBox/ValidationKit/tests/usb/Makefile.kmk new file mode 100644 index 00000000..3a4741cc --- /dev/null +++ b/src/VBox/ValidationKit/tests/usb/Makefile.kmk @@ -0,0 +1,53 @@ +# $Id: Makefile.kmk $ +## @file +# VirtualBox Validation Kit - USB. +# + +# +# Copyright (C) 2014-2023 Oracle and/or its affiliates. +# +# This file is part of VirtualBox base platform packages, as +# available from https://www.virtualbox.org. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation, in version 3 of the +# License. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see <https://www.gnu.org/licenses>. +# +# The contents of this file may alternatively be used under the terms +# of the Common Development and Distribution License Version 1.0 +# (CDDL), a copy of it is provided in the "COPYING.CDDL" file included +# in the VirtualBox distribution, in which case the provisions of the +# CDDL are applicable instead of those of the GPL. +# +# You may elect to license modified versions of this file under the +# terms and conditions of either the GPL or the CDDL or both. +# +# SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0 +# + +SUB_DEPTH = ../../../../.. +include $(KBUILD_PATH)/subheader.kmk + + +INSTALLS += ValidationKitTestsUsb +ValidationKitTestsUsb_TEMPLATE = VBoxValidationKitR3 +ValidationKitTestsUsb_INST = $(INST_VALIDATIONKIT)tests/usb/ +ValidationKitTestsUsb_EXEC_SOURCES := \ + $(PATH_SUB_CURRENT)/tdUsb1.py \ + $(PATH_SUB_CURRENT)/usbgadget.py \ + $(PATH_SUB_CURRENT)/tst-utsgadget.py + +VBOX_VALIDATIONKIT_PYTHON_SOURCES += $(ValidationKitTestsUsb_EXEC_SOURCES) + +$(evalcall def_vbox_validationkit_process_python_sources) +include $(FILE_KBUILD_SUB_FOOTER) + diff --git a/src/VBox/ValidationKit/tests/usb/tdUsb1.py b/src/VBox/ValidationKit/tests/usb/tdUsb1.py new file mode 100755 index 00000000..9dd20ebf --- /dev/null +++ b/src/VBox/ValidationKit/tests/usb/tdUsb1.py @@ -0,0 +1,590 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# $Id: tdUsb1.py $ + +""" +VirtualBox Validation Kit - USB testcase and benchmark. +""" + +__copyright__ = \ +""" +Copyright (C) 2014-2023 Oracle and/or its affiliates. + +This file is part of VirtualBox base platform packages, as +available from https://www.virtualbox.org. + +This program is free software; you can redistribute it and/or +modify it under the terms of the GNU General Public License +as published by the Free Software Foundation, in version 3 of the +License. + +This program is distributed in the hope that it will be useful, but +WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program; if not, see <https://www.gnu.org/licenses>. + +The contents of this file may alternatively be used under the terms +of the Common Development and Distribution License Version 1.0 +(CDDL), a copy of it is provided in the "COPYING.CDDL" file included +in the VirtualBox distribution, in which case the provisions of the +CDDL are applicable instead of those of the GPL. + +You may elect to license modified versions of this file under the +terms and conditions of either the GPL or the CDDL or both. + +SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0 +""" +__version__ = "$Revision: 155244 $" + + +# Standard Python imports. +import os; +import sys; +import socket; + +# Only the main script needs to modify the path. +try: __file__ +except: __file__ = sys.argv[0]; +g_ksValidationKitDir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))); +sys.path.append(g_ksValidationKitDir); + +# Validation Kit imports. +from testdriver import reporter; +from testdriver import base; +from testdriver import vbox; +from testdriver import vboxcon; + +# USB gadget control import +import usbgadget; + +# Python 3 hacks: +if sys.version_info[0] >= 3: + xrange = range; # pylint: disable=redefined-builtin,invalid-name + + +class tdUsbBenchmark(vbox.TestDriver): # pylint: disable=too-many-instance-attributes + """ + USB benchmark. + """ + + # The available test devices + # + # The first key is the hostname of the host the test is running on. + # It contains a new dictionary with the attached gadgets based on the + # USB speed we want to test (Low, Full, High, Super). + # The parameters consist of the hostname of the gadget in the network + # and the hardware type. + kdGadgetParams = { + 'adaris': { + 'Low': ('usbtest.de.oracle.com', None), + 'Full': ('usbtest.de.oracle.com', None), + 'High': ('usbtest.de.oracle.com', None), + 'Super': ('usbtest.de.oracle.com', None) + }, + }; + + # Mappings of USB controllers to supported USB device speeds. + kdUsbSpeedMappings = { + 'OHCI': ['Low', 'Full'], + 'EHCI': ['High'], + 'XHCI': ['Low', 'Full', 'High', 'Super'] + }; + + # Tests currently disabled because they fail, need investigation. + kdUsbTestsDisabled = { + 'Low': [24], + 'Full': [24], + 'High': [24], + 'Super': [24] + }; + + def __init__(self): + vbox.TestDriver.__init__(self); + self.asRsrcs = None; + self.asTestVMsDef = ['tst-arch']; + self.asTestVMs = self.asTestVMsDef; + self.asSkipVMs = []; + self.asVirtModesDef = ['hwvirt', 'hwvirt-np', 'raw']; + self.asVirtModes = self.asVirtModesDef; + self.acCpusDef = [1, 2,]; + self.acCpus = self.acCpusDef; + self.asUsbCtrlsDef = ['OHCI', 'EHCI', 'XHCI']; + self.asUsbCtrls = self.asUsbCtrlsDef; + self.asUsbSpeedDef = ['Low', 'Full', 'High', 'Super']; + self.asUsbSpeed = self.asUsbSpeedDef; + self.asUsbTestsDef = ['Compliance', 'Reattach']; + self.asUsbTests = self.asUsbTestsDef; + self.cUsbReattachCyclesDef = 100; + self.cUsbReattachCycles = self.cUsbReattachCyclesDef; + self.sHostname = socket.gethostname().lower(); + self.sGadgetHostnameDef = 'usbtest.de.oracle.com'; + self.uGadgetPortDef = None; + self.sUsbCapturePathDef = self.sScratchPath; + self.sUsbCapturePath = self.sUsbCapturePathDef; + self.fUsbCapture = False; + + # + # Overridden methods. + # + def showUsage(self): + rc = vbox.TestDriver.showUsage(self); + reporter.log(''); + reporter.log('tdUsb1 Options:'); + reporter.log(' --virt-modes <m1[:m2[:]]'); + reporter.log(' Default: %s' % (':'.join(self.asVirtModesDef))); + reporter.log(' --cpu-counts <c1[:c2[:]]'); + reporter.log(' Default: %s' % (':'.join(str(c) for c in self.acCpusDef))); + reporter.log(' --test-vms <vm1[:vm2[:...]]>'); + reporter.log(' Test the specified VMs in the given order. Use this to change'); + reporter.log(' the execution order or limit the choice of VMs'); + reporter.log(' Default: %s (all)' % (':'.join(self.asTestVMsDef))); + reporter.log(' --skip-vms <vm1[:vm2[:...]]>'); + reporter.log(' Skip the specified VMs when testing.'); + reporter.log(' --usb-ctrls <u1[:u2[:]]'); + reporter.log(' Default: %s' % (':'.join(str(c) for c in self.asUsbCtrlsDef))); + reporter.log(' --usb-speed <s1[:s2[:]]'); + reporter.log(' Default: %s' % (':'.join(str(c) for c in self.asUsbSpeedDef))); + reporter.log(' --usb-tests <s1[:s2[:]]'); + reporter.log(' Default: %s' % (':'.join(str(c) for c in self.asUsbTestsDef))); + reporter.log(' --usb-reattach-cycles <cycles>'); + reporter.log(' Default: %s' % (self.cUsbReattachCyclesDef)); + reporter.log(' --hostname: <hostname>'); + reporter.log(' Default: %s' % (self.sHostname)); + reporter.log(' --default-gadget-host <hostname>'); + reporter.log(' Default: %s' % (self.sGadgetHostnameDef)); + reporter.log(' --default-gadget-port <port>'); + reporter.log(' Default: %s' % (6042)); + reporter.log(' --usb-capture-path <path>'); + reporter.log(' Default: %s' % (self.sUsbCapturePathDef)); + reporter.log(' --usb-capture'); + reporter.log(' Whether to capture the USB traffic for each test'); + return rc; + + def parseOption(self, asArgs, iArg): # pylint: disable=too-many-branches,too-many-statements + if asArgs[iArg] == '--virt-modes': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--virt-modes" takes a colon separated list of modes'); + self.asVirtModes = asArgs[iArg].split(':'); + for s in self.asVirtModes: + if s not in self.asVirtModesDef: + raise base.InvalidOption('The "--virt-modes" value "%s" is not valid; valid values are: %s' \ + % (s, ' '.join(self.asVirtModesDef))); + elif asArgs[iArg] == '--cpu-counts': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--cpu-counts" takes a colon separated list of cpu counts'); + self.acCpus = []; + for s in asArgs[iArg].split(':'): + try: c = int(s); + except: raise base.InvalidOption('The "--cpu-counts" value "%s" is not an integer' % (s,)); + if c <= 0: raise base.InvalidOption('The "--cpu-counts" value "%s" is zero or negative' % (s,)); + self.acCpus.append(c); + elif asArgs[iArg] == '--test-vms': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--test-vms" takes colon separated list'); + self.asTestVMs = asArgs[iArg].split(':'); + for s in self.asTestVMs: + if s not in self.asTestVMsDef: + raise base.InvalidOption('The "--test-vms" value "%s" is not valid; valid values are: %s' \ + % (s, ' '.join(self.asTestVMsDef))); + elif asArgs[iArg] == '--skip-vms': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--skip-vms" takes colon separated list'); + self.asSkipVMs = asArgs[iArg].split(':'); + for s in self.asSkipVMs: + if s not in self.asTestVMsDef: + reporter.log('warning: The "--test-vms" value "%s" does not specify any of our test VMs.' % (s)); + elif asArgs[iArg] == '--usb-ctrls': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-ctrls" takes a colon separated list of USB controllers'); + self.asUsbCtrls = asArgs[iArg].split(':'); + for s in self.asUsbCtrls: + if s not in self.asUsbCtrlsDef: + reporter.log('warning: The "--usb-ctrls" value "%s" is not a valid USB controller.' % (s)); + elif asArgs[iArg] == '--usb-speed': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-speed" takes a colon separated list of USB speeds'); + self.asUsbSpeed = asArgs[iArg].split(':'); + for s in self.asUsbSpeed: + if s not in self.asUsbSpeedDef: + reporter.log('warning: The "--usb-speed" value "%s" is not a valid USB speed.' % (s)); + elif asArgs[iArg] == '--usb-tests': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-tests" takes a colon separated list of USB tests'); + self.asUsbTests = asArgs[iArg].split(':'); + for s in self.asUsbTests: + if s not in self.asUsbTestsDef: + reporter.log('warning: The "--usb-tests" value "%s" is not a valid USB test.' % (s)); + elif asArgs[iArg] == '--usb-reattach-cycles': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-reattach-cycles" takes cycle count'); + try: self.cUsbReattachCycles = int(asArgs[iArg]); + except: raise base.InvalidOption('The "--usb-reattach-cycles" value "%s" is not an integer' \ + % (asArgs[iArg],)); + if self.cUsbReattachCycles <= 0: + raise base.InvalidOption('The "--usb-reattach-cycles" value "%s" is zero or negative.' \ + % (self.cUsbReattachCycles,)); + elif asArgs[iArg] == '--hostname': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--hostname" takes a hostname'); + self.sHostname = asArgs[iArg]; + elif asArgs[iArg] == '--default-gadget-host': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--default-gadget-host" takes a hostname'); + self.sGadgetHostnameDef = asArgs[iArg]; + elif asArgs[iArg] == '--default-gadget-port': + iArg += 1; + if iArg >= len(asArgs): raise base.InvalidOption('The "--default-gadget-port" takes port number'); + try: self.uGadgetPortDef = int(asArgs[iArg]); + except: raise base.InvalidOption('The "--default-gadget-port" value "%s" is not an integer' \ + % (asArgs[iArg],)); + if self.uGadgetPortDef <= 0: + raise base.InvalidOption('The "--default-gadget-port" value "%s" is zero or negative.' \ + % (self.uGadgetPortDef,)); + elif asArgs[iArg] == '--usb-capture-path': + if iArg >= len(asArgs): raise base.InvalidOption('The "--usb-capture-path" takes a path argument'); + self.sUsbCapturePath = asArgs[iArg]; + elif asArgs[iArg] == '--usb-capture': + self.fUsbCapture = True; + else: + return vbox.TestDriver.parseOption(self, asArgs, iArg); + return iArg + 1; + + def completeOptions(self): + # Remove skipped VMs from the test list. + for sVM in self.asSkipVMs: + try: self.asTestVMs.remove(sVM); + except: pass; + + return vbox.TestDriver.completeOptions(self); + + def getResourceSet(self): + # Construct the resource list the first time it's queried. + if self.asRsrcs is None: + self.asRsrcs = []; + + if 'tst-arch' in self.asTestVMs: + self.asRsrcs.append('4.2/usb/tst-arch.vdi'); + + return self.asRsrcs; + + def actionConfig(self): + + # Some stupid trickery to guess the location of the iso. ## fixme - testsuite unzip ++ + sVBoxValidationKit_iso = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../VBoxValidationKit.iso')); + if not os.path.isfile(sVBoxValidationKit_iso): + sVBoxValidationKit_iso = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../VBoxTestSuite.iso')); + if not os.path.isfile(sVBoxValidationKit_iso): + sVBoxValidationKit_iso = '/mnt/ramdisk/vbox/svn/trunk/validationkit/VBoxValidationKit.iso'; + if not os.path.isfile(sVBoxValidationKit_iso): + sVBoxValidationKit_iso = '/mnt/ramdisk/vbox/svn/trunk/testsuite/VBoxTestSuite.iso'; + if not os.path.isfile(sVBoxValidationKit_iso): + sCur = os.getcwd(); + for i in range(0, 10): + sVBoxValidationKit_iso = os.path.join(sCur, 'validationkit/VBoxValidationKit.iso'); + if os.path.isfile(sVBoxValidationKit_iso): + break; + sVBoxValidationKit_iso = os.path.join(sCur, 'testsuite/VBoxTestSuite.iso'); + if os.path.isfile(sVBoxValidationKit_iso): + break; + sCur = os.path.abspath(os.path.join(sCur, '..')); + if i is None: pass; # shut up pychecker/pylint. + if not os.path.isfile(sVBoxValidationKit_iso): + sVBoxValidationKit_iso = '/home/bird/validationkit/VBoxValidationKit.iso'; + if not os.path.isfile(sVBoxValidationKit_iso): + sVBoxValidationKit_iso = '/home/bird/testsuite/VBoxTestSuite.iso'; + + # Make sure vboxapi has been imported so we can use the constants. + if not self.importVBoxApi(): + return False; + + # + # Configure the VMs we're going to use. + # + + # Linux VMs + if 'tst-arch' in self.asTestVMs: + oVM = self.createTestVM('tst-arch', 1, '4.2/usb/tst-arch.vdi', sKind = 'ArchLinux_64', fIoApic = True, \ + eNic0AttachType = vboxcon.NetworkAttachmentType_NAT, \ + sDvdImage = sVBoxValidationKit_iso); + if oVM is None: + return False; + + return True; + + def actionExecute(self): + """ + Execute the testcase. + """ + fRc = self.testUsb(); + return fRc; + + def getGadgetParams(self, sHostname, sSpeed): + """ + Returns the gadget hostname and port from the + given hostname the test is running on and device speed we want to test. + """ + kdGadgetsConfigured = self.kdGadgetParams.get(sHostname); + if kdGadgetsConfigured is not None: + return kdGadgetsConfigured.get(sSpeed); + + return (self.sGadgetHostnameDef, self.uGadgetPortDef); + + def getCaptureFilePath(self, sUsbCtrl, sSpeed): + """ + Returns capture filename from the given data. + """ + + return '%s%s%s-%s.pcap' % (self.sUsbCapturePath, os.sep, sUsbCtrl, sSpeed); + + def attachUsbDeviceToVm(self, oSession, sVendorId, sProductId, iBusId, + sCaptureFile = None): + """ + Attaches the given USB device to the VM either via a filter + or directly if capturing the USB traffic is enabled. + + Returns True on success, False on failure. + """ + fRc = False; + if sCaptureFile is None: + fRc = oSession.addUsbDeviceFilter('Compliance device', sVendorId = sVendorId, sProductId = sProductId, \ + sPort = format(iBusId, 'x')); + else: + # Search for the correct device in the USB device list waiting for some time + # to let it appear. + iVendorId = int(sVendorId, 16); + iProductId = int(sProductId, 16); + + # Try a few times to give VBoxSVC a chance to detect the new device. + for _ in xrange(5): + fFound = False; + aoUsbDevs = self.oVBoxMgr.getArray(self.oVBox.host, 'USBDevices'); + for oUsbDev in aoUsbDevs: + if oUsbDev.vendorId == iVendorId \ + and oUsbDev.productId == iProductId \ + and oUsbDev.port == iBusId: + fFound = True; + fRc = oSession.attachUsbDevice(oUsbDev.id, sCaptureFile); + break; + + if fFound: + break; + + # Wait a moment until the next try. + self.sleep(1); + + if fRc: + # Wait a moment to let the USB device appear + self.sleep(9); + + return fRc; + + # + # Test execution helpers. + # + def testUsbCompliance(self, oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile = None): + """ + Test VirtualBoxs USB stack in a VM. + """ + # Get configured USB test devices from hostname we are running on + sGadgetHost, uGadgetPort = self.getGadgetParams(self.sHostname, sSpeed); + + oUsbGadget = usbgadget.UsbGadget(); + reporter.log('Connecting to UTS: ' + sGadgetHost); + fRc = oUsbGadget.connectTo(30 * 1000, sGadgetHost, uPort = uGadgetPort, fTryConnect = True); + if fRc is True: + reporter.log('Connect succeeded'); + self.oVBox.host.addUSBDeviceSource('USBIP', sGadgetHost, sGadgetHost + (':%s' % oUsbGadget.getUsbIpPort()), [], []); + + fSuperSpeed = False; + if sSpeed == 'Super': + fSuperSpeed = True; + + # Create test device gadget and a filter to attach the device automatically. + fRc = oUsbGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest, fSuperSpeed); + if fRc is True: + iBusId, _ = oUsbGadget.getGadgetBusAndDevId(); + fRc = self.attachUsbDeviceToVm(oSession, '0525', 'a4a0', iBusId, sCaptureFile); + if fRc is True: + tupCmdLine = ('UsbTest', ); + # Exclude a few tests which hang and cause a timeout, need investigation. + lstTestsExclude = self.kdUsbTestsDisabled.get(sSpeed); + for iTestExclude in lstTestsExclude: + tupCmdLine = tupCmdLine + ('--exclude', str(iTestExclude)); + + fRc = self.txsRunTest(oTxsSession, 'UsbTest', 3600 * 1000, \ + '${CDROM}/${OS/ARCH}/UsbTest${EXESUFF}', tupCmdLine); + if not fRc: + reporter.testFailure('Running USB test utility failed'); + else: + reporter.testFailure('Failed to attach USB device to VM'); + oUsbGadget.disconnectFrom(); + else: + reporter.testFailure('Failed to impersonate test device'); + + self.oVBox.host.removeUSBDeviceSource(sGadgetHost); + else: + reporter.log('warning: Failed to connect to USB gadget'); + fRc = None + + _ = sUsbCtrl; + return fRc; + + def testUsbReattach(self, oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile = None): # pylint: disable=unused-argument + """ + Tests that rapid connect/disconnect cycles work. + """ + # Get configured USB test devices from hostname we are running on + sGadgetHost, uGadgetPort = self.getGadgetParams(self.sHostname, sSpeed); + + oUsbGadget = usbgadget.UsbGadget(); + reporter.log('Connecting to UTS: ' + sGadgetHost); + fRc = oUsbGadget.connectTo(30 * 1000, sGadgetHost, uPort = uGadgetPort, fTryConnect = True); + if fRc is True: + self.oVBox.host.addUSBDeviceSource('USBIP', sGadgetHost, sGadgetHost + (':%s' % oUsbGadget.getUsbIpPort()), [], []); + + fSuperSpeed = False; + if sSpeed == 'Super': + fSuperSpeed = True; + + # Create test device gadget and a filter to attach the device automatically. + fRc = oUsbGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest, fSuperSpeed); + if fRc is True: + iBusId, _ = oUsbGadget.getGadgetBusAndDevId(); + fRc = self.attachUsbDeviceToVm(oSession, '0525', 'a4a0', iBusId, sCaptureFile); + if fRc is True: + + # Wait a moment to let the USB device appear + self.sleep(3); + + # Do a rapid disconnect reconnect cycle. Wait a second before disconnecting + # again or it will happen so fast that the VM can't attach the new device. + # @todo: Get rid of the constant wait and use an event to get notified when + # the device was attached. + for iCycle in xrange (0, self.cUsbReattachCycles): + fRc = oUsbGadget.disconnectUsb(); + fRc = fRc and oUsbGadget.connectUsb(); + if not fRc: + reporter.testFailure('Reattach cycle %s failed on the gadget device' % (iCycle)); + break; + self.sleep(1); + + else: + reporter.testFailure('Failed to create USB device filter'); + + oUsbGadget.disconnectFrom(); + else: + reporter.testFailure('Failed to impersonate test device'); + else: + reporter.log('warning: Failed to connect to USB gadget'); + fRc = None + + return fRc; + + def testUsbOneCfg(self, sVmName, sUsbCtrl, sSpeed, sUsbTest): + """ + Runs the specified VM thru one specified test. + + Returns a success indicator on the general test execution. This is not + the actual test result. + """ + oVM = self.getVmByName(sVmName); + + # Reconfigure the VM + fRc = True; + oSession = self.openSession(oVM); + if oSession is not None: + fRc = fRc and oSession.enableVirtEx(True); + fRc = fRc and oSession.enableNestedPaging(True); + + # Make sure controllers are disabled initially. + fRc = fRc and oSession.enableUsbOhci(False); + fRc = fRc and oSession.enableUsbEhci(False); + fRc = fRc and oSession.enableUsbXhci(False); + + if sUsbCtrl == 'OHCI': + fRc = fRc and oSession.enableUsbOhci(True); + elif sUsbCtrl == 'EHCI': + fRc = fRc and oSession.enableUsbEhci(True); + elif sUsbCtrl == 'XHCI': + fRc = fRc and oSession.enableUsbXhci(True); + fRc = fRc and oSession.saveSettings(); + fRc = oSession.close() and fRc and True; # pychecker hack. + oSession = None; + else: + fRc = False; + + # Start up. + if fRc is True: + self.logVmInfo(oVM); + oSession, oTxsSession = self.startVmAndConnectToTxsViaTcp(sVmName, fCdWait = False, fNatForwardingForTxs = False); + if oSession is not None: + self.addTask(oTxsSession); + + # Fudge factor - Allow the guest to finish starting up. + self.sleep(5); + + sCaptureFile = None; + if self.fUsbCapture: + sCaptureFile = self.getCaptureFilePath(sUsbCtrl, sSpeed); + + if sUsbTest == 'Compliance': + fRc = self.testUsbCompliance(oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile); + elif sUsbTest == 'Reattach': + fRc = self.testUsbReattach(oSession, oTxsSession, sUsbCtrl, sSpeed, sCaptureFile); + + # cleanup. + self.removeTask(oTxsSession); + self.terminateVmBySession(oSession) + + # Add the traffic dump if it exists and the test failed + if reporter.testErrorCount() > 0 \ + and sCaptureFile is not None \ + and os.path.exists(sCaptureFile): + reporter.addLogFile(sCaptureFile, 'misc/other', 'USB traffic dump'); + else: + fRc = False; + return fRc; + + def testUsbForOneVM(self, sVmName): + """ + Runs one VM thru the various configurations. + """ + fRc = False; + reporter.testStart(sVmName); + for sUsbCtrl in self.asUsbCtrls: + reporter.testStart(sUsbCtrl) + for sUsbSpeed in self.asUsbSpeed: + asSupportedSpeeds = self.kdUsbSpeedMappings.get(sUsbCtrl); + if sUsbSpeed in asSupportedSpeeds: + reporter.testStart(sUsbSpeed) + for sUsbTest in self.asUsbTests: + reporter.testStart(sUsbTest) + fRc = self.testUsbOneCfg(sVmName, sUsbCtrl, sUsbSpeed, sUsbTest); + reporter.testDone(); + reporter.testDone(); + reporter.testDone(); + reporter.testDone(); + return fRc; + + def testUsb(self): + """ + Executes USB test. + """ + + reporter.log("Running on host: " + self.sHostname); + + # Loop thru the test VMs. + for sVM in self.asTestVMs: + # run test on the VM. + fRc = self.testUsbForOneVM(sVM); + + return fRc; + + + +if __name__ == '__main__': + sys.exit(tdUsbBenchmark().main(sys.argv)); + diff --git a/src/VBox/ValidationKit/tests/usb/tst-utsgadget.py b/src/VBox/ValidationKit/tests/usb/tst-utsgadget.py new file mode 100755 index 00000000..03245e00 --- /dev/null +++ b/src/VBox/ValidationKit/tests/usb/tst-utsgadget.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +# $Id: tst-utsgadget.py $ + +""" +Simple testcase for usbgadget2.py. +""" + +__copyright__ = \ +""" +Copyright (C) 2016-2023 Oracle and/or its affiliates. + +This file is part of VirtualBox base platform packages, as +available from https://www.virtualbox.org. + +This program is free software; you can redistribute it and/or +modify it under the terms of the GNU General Public License +as published by the Free Software Foundation, in version 3 of the +License. + +This program is distributed in the hope that it will be useful, but +WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program; if not, see <https://www.gnu.org/licenses>. + +The contents of this file may alternatively be used under the terms +of the Common Development and Distribution License Version 1.0 +(CDDL), a copy of it is provided in the "COPYING.CDDL" file included +in the VirtualBox distribution, in which case the provisions of the +CDDL are applicable instead of those of the GPL. + +You may elect to license modified versions of this file under the +terms and conditions of either the GPL or the CDDL or both. + +SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0 +""" +__version__ = "$Revision: 155244 $" + +# Standard python imports. +import sys + +# Validation Kit imports. +sys.path.insert(0, '.'); +sys.path.insert(0, '..'); +sys.path.insert(0, '../..'); +from common import utils; +from testdriver import reporter; +import usbgadget; + + +# Python 3 hacks: +if sys.version_info[0] >= 3: + long = int; # pylint: disable=redefined-builtin,invalid-name + + +g_cTests = 0; +g_cFailures = 0 + +def boolRes(rc, fExpect = True): + """Checks a boolean result.""" + global g_cTests, g_cFailures; + g_cTests = g_cTests + 1; + if isinstance(rc, bool): + if rc == fExpect: + return 'PASSED'; + g_cFailures = g_cFailures + 1; + return 'FAILED'; + +def stringRes(rc, sExpect): + """Checks a string result.""" + global g_cTests, g_cFailures; + g_cTests = g_cTests + 1; + if utils.isString(rc): + if rc == sExpect: + return 'PASSED'; + g_cFailures = g_cFailures + 1; + return 'FAILED'; + +def main(asArgs): # pylint: disable=missing-docstring,too-many-locals,too-many-statements + cMsTimeout = long(30*1000); + sAddress = 'localhost'; + uPort = None; + fStdTests = True; + + i = 1; + while i < len(asArgs): + if asArgs[i] == '--hostname': + sAddress = asArgs[i + 1]; + i = i + 2; + elif asArgs[i] == '--port': + uPort = int(asArgs[i + 1]); + i = i + 2; + elif asArgs[i] == '--timeout': + cMsTimeout = long(asArgs[i + 1]); + i = i + 2; + elif asArgs[i] == '--help': + print('tst-utsgadget.py [--hostname <addr|name>] [--port <num>] [--timeout <cMS>]'); + return 0; + else: + print('Unknown argument: %s' % (asArgs[i],)); + return 2; + + oGadget = usbgadget.UsbGadget(); + if uPort is None: + rc = oGadget.connectTo(cMsTimeout, sAddress); + else: + rc = oGadget.connectTo(cMsTimeout, sAddress, uPort = uPort); + if rc is False: + print('connectTo failed'); + return 1; + + if fStdTests: + rc = oGadget.getUsbIpPort() is not None; + print('%s: getUsbIpPort() -> %s' % (boolRes(rc), oGadget.getUsbIpPort(),)); + + rc = oGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest); + print('%s: impersonate()' % (boolRes(rc),)); + + rc = oGadget.disconnectUsb(); + print('%s: disconnectUsb()' % (boolRes(rc),)); + + rc = oGadget.connectUsb(); + print('%s: connectUsb()' % (boolRes(rc),)); + + rc = oGadget.clearImpersonation(); + print('%s: clearImpersonation()' % (boolRes(rc),)); + + # Test super speed (and therefore passing configuration items) + rc = oGadget.impersonate(usbgadget.g_ksGadgetImpersonationTest, True); + print('%s: impersonate(, True)' % (boolRes(rc),)); + + rc = oGadget.clearImpersonation(); + print('%s: clearImpersonation()' % (boolRes(rc),)); + + # Done + rc = oGadget.disconnectFrom(); + print('%s: disconnectFrom() -> %s' % (boolRes(rc), rc,)); + + if g_cFailures != 0: + print('tst-utsgadget.py: %u out of %u test failed' % (g_cFailures, g_cTests,)); + return 1; + print('tst-utsgadget.py: all %u tests passed!' % (g_cTests,)); + return 0; + + +if __name__ == '__main__': + reporter.incVerbosity(); + reporter.incVerbosity(); + reporter.incVerbosity(); + reporter.incVerbosity(); + sys.exit(main(sys.argv)); + diff --git a/src/VBox/ValidationKit/tests/usb/usbgadget.py b/src/VBox/ValidationKit/tests/usb/usbgadget.py new file mode 100755 index 00000000..0a7c4805 --- /dev/null +++ b/src/VBox/ValidationKit/tests/usb/usbgadget.py @@ -0,0 +1,1478 @@ +# -*- coding: utf-8 -*- +# $Id: usbgadget.py $ +# pylint: disable=too-many-lines + +""" +UTS (USB Test Service) client. +""" +__copyright__ = \ +""" +Copyright (C) 2010-2023 Oracle and/or its affiliates. + +This file is part of VirtualBox base platform packages, as +available from https://www.virtualbox.org. + +This program is free software; you can redistribute it and/or +modify it under the terms of the GNU General Public License +as published by the Free Software Foundation, in version 3 of the +License. + +This program is distributed in the hope that it will be useful, but +WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program; if not, see <https://www.gnu.org/licenses>. + +The contents of this file may alternatively be used under the terms +of the Common Development and Distribution License Version 1.0 +(CDDL), a copy of it is provided in the "COPYING.CDDL" file included +in the VirtualBox distribution, in which case the provisions of the +CDDL are applicable instead of those of the GPL. + +You may elect to license modified versions of this file under the +terms and conditions of either the GPL or the CDDL or both. + +SPDX-License-Identifier: GPL-3.0-only OR CDDL-1.0 +""" +__version__ = "$Revision: 155244 $" + +# Standard Python imports. +import array +import errno +import select +import socket +import sys; +import threading +import time +import zlib + +# Validation Kit imports. +from common import utils; +from testdriver import base; +from testdriver import reporter; +from testdriver.base import TdTaskBase; + +# Python 3 hacks: +if sys.version_info[0] >= 3: + long = int; # pylint: disable=redefined-builtin,invalid-name + + +## @name USB gadget impersonation string constants. +## @{ +g_ksGadgetImpersonationInvalid = 'Invalid'; +g_ksGadgetImpersonationTest = 'Test'; +g_ksGadgetImpersonationMsd = 'Msd'; +g_ksGadgetImpersonationWebcam = 'Webcam'; +g_ksGadgetImpersonationEther = 'Ether'; +## @} + +## @name USB gadget type used in the UTS protocol. +## @{ +g_kiGadgetTypeTest = 1; +## @} + +## @name USB gadget access methods used in the UTS protocol. +## @{ +g_kiGadgetAccessUsbIp = 1; +## @} + +## @name USB gadget config types. +## @{ +g_kiGadgetCfgTypeBool = 1; +g_kiGadgetCfgTypeString = 2; +g_kiGadgetCfgTypeUInt8 = 3; +g_kiGadgetCfgTypeUInt16 = 4; +g_kiGadgetCfgTypeUInt32 = 5; +g_kiGadgetCfgTypeUInt64 = 6; +g_kiGadgetCfgTypeInt8 = 7; +g_kiGadgetCfgTypeInt16 = 8; +g_kiGadgetCfgTypeInt32 = 9; +g_kiGadgetCfgTypeInt64 = 10; +## @} + +# +# Helpers for decoding data received from the UTS. +# These are used both the Session and Transport classes. +# + +def getU64(abData, off): + """Get a U64 field.""" + return abData[off] \ + + abData[off + 1] * 256 \ + + abData[off + 2] * 65536 \ + + abData[off + 3] * 16777216 \ + + abData[off + 4] * 4294967296 \ + + abData[off + 5] * 1099511627776 \ + + abData[off + 6] * 281474976710656 \ + + abData[off + 7] * 72057594037927936; + +def getU32(abData, off): + """Get a U32 field.""" + return abData[off] \ + + abData[off + 1] * 256 \ + + abData[off + 2] * 65536 \ + + abData[off + 3] * 16777216; + +def getU16(abData, off): + """Get a U16 field.""" + return abData[off] \ + + abData[off + 1] * 256; + +def getU8(abData, off): + """Get a U8 field.""" + return abData[off]; + +def getSZ(abData, off, sDefault = None): + """ + Get a zero-terminated string field. + Returns sDefault if the string is invalid. + """ + cchStr = getSZLen(abData, off); + if cchStr >= 0: + abStr = abData[off:(off + cchStr)]; + try: + return abStr.tostring().decode('utf_8'); + except: + reporter.errorXcpt('getSZ(,%u)' % (off)); + return sDefault; + +def getSZLen(abData, off): + """ + Get the length of a zero-terminated string field, in bytes. + Returns -1 if off is beyond the data packet or not properly terminated. + """ + cbData = len(abData); + if off >= cbData: + return -1; + + offCur = off; + while abData[offCur] != 0: + offCur = offCur + 1; + if offCur >= cbData: + return -1; + + return offCur - off; + +def isValidOpcodeEncoding(sOpcode): + """ + Checks if the specified opcode is valid or not. + Returns True on success. + Returns False if it is invalid, details in the log. + """ + sSet1 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + sSet2 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_ "; + if len(sOpcode) != 8: + reporter.error("invalid opcode length: %s" % (len(sOpcode))); + return False; + for i in range(0, 1): + if sSet1.find(sOpcode[i]) < 0: + reporter.error("invalid opcode char #%u: %s" % (i, sOpcode)); + return False; + for i in range(2, 7): + if sSet2.find(sOpcode[i]) < 0: + reporter.error("invalid opcode char #%u: %s" % (i, sOpcode)); + return False; + return True; + +# +# Helper for encoding data sent to the UTS. +# + +def u32ToByteArray(u32): + """Encodes the u32 value as a little endian byte (B) array.""" + return array.array('B', \ + ( u32 % 256, \ + (u32 // 256) % 256, \ + (u32 // 65536) % 256, \ + (u32 // 16777216) % 256) ); + +def u16ToByteArray(u16): + """Encodes the u16 value as a little endian byte (B) array.""" + return array.array('B', \ + ( u16 % 256, \ + (u16 // 256) % 256) ); + +def u8ToByteArray(uint8): + """Encodes the u8 value as a little endian byte (B) array.""" + return array.array('B', (uint8 % 256)); + +def zeroByteArray(cb): + """Returns an array with the given size containing 0.""" + abArray = array.array('B', (0, )); + cb = cb - 1; + for i in range(cb): # pylint: disable=unused-variable + abArray.append(0); + return abArray; + +def strToByteArry(sStr): + """Encodes the string as a little endian byte (B) array including the terminator.""" + abArray = array.array('B'); + sUtf8 = sStr.encode('utf_8'); + for ch in sUtf8: + abArray.append(ord(ch)); + abArray.append(0); + return abArray; + +def cfgListToByteArray(lst): + """Encodes the given config list as a little endian byte (B) array.""" + abArray = array.array('B'); + if lst is not None: + for t3Item in lst: + # Encode they key size + abArray.extend(u32ToByteArray(len(t3Item[0]) + 1)); # Include terminator + abArray.extend(u32ToByteArray(t3Item[1])) # Config type + abArray.extend(u32ToByteArray(len(t3Item[2]) + 1)); # Value size including temrinator. + abArray.extend(u32ToByteArray(0)); # Reserved field. + + abArray.extend(strToByteArry(t3Item[0])); + abArray.extend(strToByteArry(t3Item[2])); + + return abArray; + +class TransportBase(object): + """ + Base class for the transport layer. + """ + + def __init__(self, sCaller): + self.sDbgCreated = '%s: %s' % (utils.getTimePrefix(), sCaller); + self.fDummy = 0; + self.abReadAheadHdr = array.array('B'); + + def toString(self): + """ + Stringify the instance for logging and debugging. + """ + return '<%s: abReadAheadHdr=%s, sDbgCreated=%s>' % (type(self).__name__, self.abReadAheadHdr, self.sDbgCreated); + + def __str__(self): + return self.toString(); + + def cancelConnect(self): + """ + Cancels any pending connect() call. + Returns None; + """ + return None; + + def connect(self, cMsTimeout): + """ + Quietly attempts to connect to the UTS. + + Returns True on success. + Returns False on retryable errors (no logging). + Returns None on fatal errors with details in the log. + + Override this method, don't call super. + """ + _ = cMsTimeout; + return False; + + def disconnect(self, fQuiet = False): + """ + Disconnect from the UTS. + + Returns True. + + Override this method, don't call super. + """ + _ = fQuiet; + return True; + + def sendBytes(self, abBuf, cMsTimeout): + """ + Sends the bytes in the buffer abBuf to the UTS. + + Returns True on success. + Returns False on failure and error details in the log. + + Override this method, don't call super. + + Remarks: len(abBuf) is always a multiple of 16. + """ + _ = abBuf; _ = cMsTimeout; + return False; + + def recvBytes(self, cb, cMsTimeout, fNoDataOk): + """ + Receive cb number of bytes from the UTS. + + Returns the bytes (array('B')) on success. + Returns None on failure and error details in the log. + + Override this method, don't call super. + + Remarks: cb is always a multiple of 16. + """ + _ = cb; _ = cMsTimeout; _ = fNoDataOk; + return None; + + def isConnectionOk(self): + """ + Checks if the connection is OK. + + Returns True if it is. + Returns False if it isn't (caller should call diconnect). + + Override this method, don't call super. + """ + return True; + + def isRecvPending(self, cMsTimeout = 0): + """ + Checks if there is incoming bytes, optionally waiting cMsTimeout + milliseconds for something to arrive. + + Returns True if there is, False if there isn't. + + Override this method, don't call super. + """ + _ = cMsTimeout; + return False; + + def sendMsgInt(self, sOpcode, cMsTimeout, abPayload = array.array('B')): + """ + Sends a message (opcode + encoded payload). + + Returns True on success. + Returns False on failure and error details in the log. + """ + # Fix + check the opcode. + if len(sOpcode) < 2: + reporter.fatal('sendMsgInt: invalid opcode length: %d (\"%s\")' % (len(sOpcode), sOpcode)); + return False; + sOpcode = sOpcode.ljust(8); + if not isValidOpcodeEncoding(sOpcode): + reporter.fatal('sendMsgInt: invalid opcode encoding: \"%s\"' % (sOpcode)); + return False; + + # Start construct the message. + cbMsg = 16 + len(abPayload); + abMsg = array.array('B'); + abMsg.extend(u32ToByteArray(cbMsg)); + abMsg.extend((0, 0, 0, 0)); # uCrc32 + try: + abMsg.extend(array.array('B', \ + ( ord(sOpcode[0]), \ + ord(sOpcode[1]), \ + ord(sOpcode[2]), \ + ord(sOpcode[3]), \ + ord(sOpcode[4]), \ + ord(sOpcode[5]), \ + ord(sOpcode[6]), \ + ord(sOpcode[7]) ) ) ); + if abPayload: + abMsg.extend(abPayload); + except: + reporter.fatalXcpt('sendMsgInt: packing problem...'); + return False; + + # checksum it, padd it and send it off. + uCrc32 = zlib.crc32(abMsg[8:]); + abMsg[4:8] = u32ToByteArray(uCrc32); + + while len(abMsg) % 16: + abMsg.append(0); + + reporter.log2('sendMsgInt: op=%s len=%d to=%d' % (sOpcode, len(abMsg), cMsTimeout)); + return self.sendBytes(abMsg, cMsTimeout); + + def recvMsg(self, cMsTimeout, fNoDataOk = False): + """ + Receives a message from the UTS. + + Returns the message three-tuple: length, opcode, payload. + Returns (None, None, None) on failure and error details in the log. + """ + + # Read the header. + if self.abReadAheadHdr: + assert(len(self.abReadAheadHdr) == 16); + abHdr = self.abReadAheadHdr; + self.abReadAheadHdr = array.array('B'); + else: + abHdr = self.recvBytes(16, cMsTimeout, fNoDataOk); # (virtual method) # pylint: disable=assignment-from-none + if abHdr is None: + return (None, None, None); + if len(abHdr) != 16: + reporter.fatal('recvBytes(16) returns %d bytes!' % (len(abHdr))); + return (None, None, None); + + # Unpack and validate the header. + cbMsg = getU32(abHdr, 0); + uCrc32 = getU32(abHdr, 4); + sOpcode = abHdr[8:16].tostring().decode('ascii'); + + if cbMsg < 16: + reporter.fatal('recvMsg: message length is out of range: %s (min 16 bytes)' % (cbMsg)); + return (None, None, None); + if cbMsg > 1024*1024: + reporter.fatal('recvMsg: message length is out of range: %s (max 1MB)' % (cbMsg)); + return (None, None, None); + if not isValidOpcodeEncoding(sOpcode): + reporter.fatal('recvMsg: invalid opcode \"%s\"' % (sOpcode)); + return (None, None, None); + + # Get the payload (if any), dropping the padding. + abPayload = array.array('B'); + if cbMsg > 16: + if cbMsg % 16: + cbPadding = 16 - (cbMsg % 16); + else: + cbPadding = 0; + abPayload = self.recvBytes(cbMsg - 16 + cbPadding, cMsTimeout, False); # pylint: disable=assignment-from-none + if abPayload is None: + self.abReadAheadHdr = abHdr; + if not fNoDataOk : + reporter.log('recvMsg: failed to recv payload bytes!'); + return (None, None, None); + + while cbPadding > 0: + abPayload.pop(); + cbPadding = cbPadding - 1; + + # Check the CRC-32. + if uCrc32 != 0: + uActualCrc32 = zlib.crc32(abHdr[8:]); + if cbMsg > 16: + uActualCrc32 = zlib.crc32(abPayload, uActualCrc32); + uActualCrc32 = uActualCrc32 & 0xffffffff; + if uCrc32 != uActualCrc32: + reporter.fatal('recvMsg: crc error: expected %s, got %s' % (hex(uCrc32), hex(uActualCrc32))); + return (None, None, None); + + reporter.log2('recvMsg: op=%s len=%d' % (sOpcode, len(abPayload))); + return (cbMsg, sOpcode, abPayload); + + def sendMsg(self, sOpcode, cMsTimeout, aoPayload = ()): + """ + Sends a message (opcode + payload tuple). + + Returns True on success. + Returns False on failure and error details in the log. + Returns None if you pass the incorrectly typed parameters. + """ + # Encode the payload. + abPayload = array.array('B'); + for o in aoPayload: + try: + if utils.isString(o): + # the primitive approach... + sUtf8 = o.encode('utf_8'); + for ch in sUtf8: + abPayload.append(ord(ch)) + abPayload.append(0); + elif isinstance(o, long): + if o < 0 or o > 0xffffffff: + reporter.fatal('sendMsg: uint32_t payload is out of range: %s' % (hex(o))); + return None; + abPayload.extend(u32ToByteArray(o)); + elif isinstance(o, int): + if o < 0 or o > 0xffffffff: + reporter.fatal('sendMsg: uint32_t payload is out of range: %s' % (hex(o))); + return None; + abPayload.extend(u32ToByteArray(o)); + elif isinstance(o, array.array): + abPayload.extend(o); + else: + reporter.fatal('sendMsg: unexpected payload type: %s (%s) (aoPayload=%s)' % (type(o), o, aoPayload)); + return None; + except: + reporter.fatalXcpt('sendMsg: screwed up the encoding code...'); + return None; + return self.sendMsgInt(sOpcode, cMsTimeout, abPayload); + + +class Session(TdTaskBase): + """ + A USB Test Service (UTS) client session. + """ + + def __init__(self, oTransport, cMsTimeout, cMsIdleFudge, fTryConnect = False): + """ + Construct a UTS session. + + This starts by connecting to the UTS and will enter the signalled state + when connected or the timeout has been reached. + """ + TdTaskBase.__init__(self, utils.getCallerName()); + self.oTransport = oTransport; + self.sStatus = ""; + self.cMsTimeout = 0; + self.fErr = True; # Whether to report errors as error. + self.msStart = 0; + self.oThread = None; + self.fnTask = self.taskDummy; + self.aTaskArgs = None; + self.oTaskRc = None; + self.t3oReply = (None, None, None); + self.fScrewedUpMsgState = False; + self.fTryConnect = fTryConnect; + + if not self.startTask(cMsTimeout, False, "connecting", self.taskConnect, (cMsIdleFudge,)): + raise base.GenError("startTask failed"); + + def __del__(self): + """Make sure to cancel the task when deleted.""" + self.cancelTask(); + + def toString(self): + return '<%s fnTask=%s, aTaskArgs=%s, sStatus=%s, oTaskRc=%s, cMsTimeout=%s,' \ + ' msStart=%s, fTryConnect=%s, fErr=%s, fScrewedUpMsgState=%s, t3oReply=%s oTransport=%s, oThread=%s>' \ + % (TdTaskBase.toString(self), self.fnTask, self.aTaskArgs, self.sStatus, self.oTaskRc, self.cMsTimeout, + self.msStart, self.fTryConnect, self.fErr, self.fScrewedUpMsgState, self.t3oReply, self.oTransport, self.oThread); + + def taskDummy(self): + """Place holder to catch broken state handling.""" + raise Exception(); + + def startTask(self, cMsTimeout, fIgnoreErrors, sStatus, fnTask, aArgs = ()): + """ + Kicks of a new task. + + cMsTimeout: The task timeout in milliseconds. Values less than + 500 ms will be adjusted to 500 ms. This means it is + OK to use negative value. + sStatus: The task status. + fnTask: The method that'll execute the task. + aArgs: Arguments to pass to fnTask. + + Returns True on success, False + error in log on failure. + """ + if not self.cancelTask(): + reporter.maybeErr(not fIgnoreErrors, 'utsclient.Session.startTask: failed to cancel previous task.'); + return False; + + # Change status and make sure we're the + self.lockTask(); + if self.sStatus != "": + self.unlockTask(); + reporter.maybeErr(not fIgnoreErrors, 'utsclient.Session.startTask: race.'); + return False; + self.sStatus = "setup"; + self.oTaskRc = None; + self.t3oReply = (None, None, None); + self.resetTaskLocked(); + self.unlockTask(); + + self.cMsTimeout = max(cMsTimeout, 500); + self.fErr = not fIgnoreErrors; + self.fnTask = fnTask; + self.aTaskArgs = aArgs; + self.oThread = threading.Thread(target=self.taskThread, args=(), name=('UTS-%s' % (sStatus))); + self.oThread.setDaemon(True); # pylint: disable=deprecated-method + self.msStart = base.timestampMilli(); + + self.lockTask(); + self.sStatus = sStatus; + self.unlockTask(); + self.oThread.start(); + + return True; + + def cancelTask(self, fSync = True): + """ + Attempts to cancel any pending tasks. + Returns success indicator (True/False). + """ + self.lockTask(); + + if self.sStatus == "": + self.unlockTask(); + return True; + if self.sStatus == "setup": + self.unlockTask(); + return False; + if self.sStatus == "cancelled": + self.unlockTask(); + return False; + + reporter.log('utsclient: cancelling "%s"...' % (self.sStatus)); + if self.sStatus == 'connecting': + self.oTransport.cancelConnect(); + + self.sStatus = "cancelled"; + oThread = self.oThread; + self.unlockTask(); + + if not fSync: + return False; + + oThread.join(61.0); + + if sys.version_info < (3, 9, 0): + # Removed since Python 3.9. + return oThread.isAlive(); # pylint: disable=no-member + return oThread.is_alive(); + + def taskThread(self): + """ + The task thread function. + This does some housekeeping activities around the real task method call. + """ + if not self.isCancelled(): + try: + fnTask = self.fnTask; + oTaskRc = fnTask(*self.aTaskArgs); + except: + reporter.fatalXcpt('taskThread', 15); + oTaskRc = None; + else: + reporter.log('taskThread: cancelled already'); + + self.lockTask(); + + reporter.log('taskThread: signalling task with status "%s", oTaskRc=%s' % (self.sStatus, oTaskRc)); + self.oTaskRc = oTaskRc; + self.oThread = None; + self.sStatus = ''; + self.signalTaskLocked(); + + self.unlockTask(); + return None; + + def isCancelled(self): + """Internal method for checking if the task has been cancelled.""" + self.lockTask(); + sStatus = self.sStatus; + self.unlockTask(); + if sStatus == "cancelled": + return True; + return False; + + def hasTimedOut(self): + """Internal method for checking if the task has timed out or not.""" + cMsLeft = self.getMsLeft(); + if cMsLeft <= 0: + return True; + return False; + + def getMsLeft(self, cMsMin = 0, cMsMax = -1): + """Gets the time left until the timeout.""" + cMsElapsed = base.timestampMilli() - self.msStart; + if cMsElapsed < 0: + return cMsMin; + cMsLeft = self.cMsTimeout - cMsElapsed; + if cMsLeft <= cMsMin: + return cMsMin; + if cMsLeft > cMsMax > 0: + return cMsMax + return cMsLeft; + + def recvReply(self, cMsTimeout = None, fNoDataOk = False): + """ + Wrapper for TransportBase.recvMsg that stashes the response away + so the client can inspect it later on. + """ + if cMsTimeout is None: + cMsTimeout = self.getMsLeft(500); + cbMsg, sOpcode, abPayload = self.oTransport.recvMsg(cMsTimeout, fNoDataOk); + self.lockTask(); + self.t3oReply = (cbMsg, sOpcode, abPayload); + self.unlockTask(); + return (cbMsg, sOpcode, abPayload); + + def recvAck(self, fNoDataOk = False): + """ + Receives an ACK or error response from the UTS. + + Returns True on success. + Returns False on timeout or transport error. + Returns (sOpcode, sDetails) tuple on failure. The opcode is stripped + and there are always details of some sort or another. + """ + cbMsg, sOpcode, abPayload = self.recvReply(None, fNoDataOk); + if cbMsg is None: + return False; + sOpcode = sOpcode.strip() + if sOpcode == "ACK": + return True; + return (sOpcode, getSZ(abPayload, 16, sOpcode)); + + def recvAckLogged(self, sCommand, fNoDataOk = False): + """ + Wrapper for recvAck and logging. + Returns True on success (ACK). + Returns False on time, transport error and errors signalled by UTS. + """ + rc = self.recvAck(fNoDataOk); + if rc is not True and not fNoDataOk: + if rc is False: + reporter.maybeErr(self.fErr, 'recvAckLogged: %s transport error' % (sCommand)); + else: + reporter.maybeErr(self.fErr, 'recvAckLogged: %s response was %s: %s' % (sCommand, rc[0], rc[1])); + rc = False; + return rc; + + def recvTrueFalse(self, sCommand): + """ + Receives a TRUE/FALSE response from the UTS. + Returns True on TRUE, False on FALSE and None on error/other (logged). + """ + cbMsg, sOpcode, abPayload = self.recvReply(); + if cbMsg is None: + reporter.maybeErr(self.fErr, 'recvAckLogged: %s transport error' % (sCommand)); + return None; + + sOpcode = sOpcode.strip() + if sOpcode == "TRUE": + return True; + if sOpcode == "FALSE": + return False; + reporter.maybeErr(self.fErr, 'recvAckLogged: %s response was %s: %s' % \ + (sCommand, sOpcode, getSZ(abPayload, 16, sOpcode))); + return None; + + def sendMsg(self, sOpcode, aoPayload = (), cMsTimeout = None): + """ + Wrapper for TransportBase.sendMsg that inserts the correct timeout. + """ + if cMsTimeout is None: + cMsTimeout = self.getMsLeft(500); + return self.oTransport.sendMsg(sOpcode, cMsTimeout, aoPayload); + + def asyncToSync(self, fnAsync, *aArgs): + """ + Wraps an asynchronous task into a synchronous operation. + + Returns False on failure, task return status on success. + """ + rc = fnAsync(*aArgs); + if rc is False: + reporter.log2('asyncToSync(%s): returns False (#1)' % (fnAsync)); + return rc; + + rc = self.waitForTask(self.cMsTimeout + 5000); + if rc is False: + reporter.maybeErrXcpt(self.fErr, 'asyncToSync: waitForTask failed...'); + self.cancelTask(); + #reporter.log2('asyncToSync(%s): returns False (#2)' % (fnAsync, rc)); + return False; + + rc = self.getResult(); + #reporter.log2('asyncToSync(%s): returns %s' % (fnAsync, rc)); + return rc; + + # + # Connection tasks. + # + + def taskConnect(self, cMsIdleFudge): + """Tries to connect to the UTS""" + while not self.isCancelled(): + reporter.log2('taskConnect: connecting ...'); + rc = self.oTransport.connect(self.getMsLeft(500)); + if rc is True: + reporter.log('taskConnect: succeeded'); + return self.taskGreet(cMsIdleFudge); + if rc is None: + reporter.log2('taskConnect: unable to connect'); + return None; + if self.hasTimedOut(): + reporter.log2('taskConnect: timed out'); + if not self.fTryConnect: + reporter.maybeErr(self.fErr, 'taskConnect: timed out'); + return False; + time.sleep(self.getMsLeft(1, 1000) / 1000.0); + if not self.fTryConnect: + reporter.maybeErr(self.fErr, 'taskConnect: cancelled'); + return False; + + def taskGreet(self, cMsIdleFudge): + """Greets the UTS""" + sHostname = socket.gethostname().lower(); + cbFill = 68 - len(sHostname) - 1; + rc = self.sendMsg("HOWDY", ((1 << 16) | 0, 0x1, len(sHostname), sHostname, zeroByteArray(cbFill))); + if rc is True: + rc = self.recvAckLogged("HOWDY", self.fTryConnect); + if rc is True: + while cMsIdleFudge > 0: + cMsIdleFudge -= 1000; + time.sleep(1); + else: + self.oTransport.disconnect(self.fTryConnect); + return rc; + + def taskBye(self): + """Says goodbye to the UTS""" + rc = self.sendMsg("BYE"); + if rc is True: + rc = self.recvAckLogged("BYE"); + self.oTransport.disconnect(); + return rc; + + # + # Gadget tasks. + # + + def taskGadgetCreate(self, iGadgetType, iGadgetAccess, lstCfg = None): + """Creates a new gadget on UTS""" + cCfgItems = 0; + if lstCfg is not None: + cCfgItems = len(lstCfg); + fRc = self.sendMsg("GDGTCRT", (iGadgetType, iGadgetAccess, cCfgItems, 0, cfgListToByteArray(lstCfg))); + if fRc is True: + fRc = self.recvAckLogged("GDGTCRT"); + return fRc; + + def taskGadgetDestroy(self, iGadgetId): + """Destroys the given gadget handle on UTS""" + fRc = self.sendMsg("GDGTDTOR", (iGadgetId, zeroByteArray(12))); + if fRc is True: + fRc = self.recvAckLogged("GDGTDTOR"); + return fRc; + + def taskGadgetConnect(self, iGadgetId): + """Connects the given gadget handle on UTS""" + fRc = self.sendMsg("GDGTCNCT", (iGadgetId, zeroByteArray(12))); + if fRc is True: + fRc = self.recvAckLogged("GDGTCNCT"); + return fRc; + + def taskGadgetDisconnect(self, iGadgetId): + """Disconnects the given gadget handle from UTS""" + fRc = self.sendMsg("GDGTDCNT", (iGadgetId, zeroByteArray(12))); + if fRc is True: + fRc = self.recvAckLogged("GDGTDCNT"); + return fRc; + + # + # Public methods - generic task queries + # + + def isSuccess(self): + """Returns True if the task completed successfully, otherwise False.""" + self.lockTask(); + sStatus = self.sStatus; + oTaskRc = self.oTaskRc; + self.unlockTask(); + if sStatus != "": + return False; + if oTaskRc is False or oTaskRc is None: + return False; + return True; + + def getResult(self): + """ + Returns the result of a completed task. + Returns None if not completed yet or no previous task. + """ + self.lockTask(); + sStatus = self.sStatus; + oTaskRc = self.oTaskRc; + self.unlockTask(); + if sStatus != "": + return None; + return oTaskRc; + + def getLastReply(self): + """ + Returns the last reply three-tuple: cbMsg, sOpcode, abPayload. + Returns a None, None, None three-tuple if there was no last reply. + """ + self.lockTask(); + t3oReply = self.t3oReply; + self.unlockTask(); + return t3oReply; + + # + # Public methods - connection. + # + + def asyncDisconnect(self, cMsTimeout = 30000, fIgnoreErrors = False): + """ + Initiates a disconnect task. + + Returns True on success, False on failure (logged). + + The task returns True on success and False on failure. + """ + return self.startTask(cMsTimeout, fIgnoreErrors, "bye", self.taskBye); + + def syncDisconnect(self, cMsTimeout = 30000, fIgnoreErrors = False): + """Synchronous version.""" + return self.asyncToSync(self.asyncDisconnect, cMsTimeout, fIgnoreErrors); + + # + # Public methods - gadget API + # + + def asyncGadgetCreate(self, iGadgetType, iGadgetAccess, lstCfg = None, cMsTimeout = 30000, fIgnoreErrors = False): + """ + Initiates a gadget create task. + + Returns True on success, False on failure (logged). + + The task returns True on success and False on failure. + """ + return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetCreate", self.taskGadgetCreate, \ + (iGadgetType, iGadgetAccess, lstCfg)); + + def syncGadgetCreate(self, iGadgetType, iGadgetAccess, lstCfg = None, cMsTimeout = 30000, fIgnoreErrors = False): + """Synchronous version.""" + return self.asyncToSync(self.asyncGadgetCreate, iGadgetType, iGadgetAccess, lstCfg, cMsTimeout, fIgnoreErrors); + + def asyncGadgetDestroy(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False): + """ + Initiates a gadget destroy task. + + Returns True on success, False on failure (logged). + + The task returns True on success and False on failure. + """ + return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetDestroy", self.taskGadgetDestroy, \ + (iGadgetId, )); + + def syncGadgetDestroy(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False): + """Synchronous version.""" + return self.asyncToSync(self.asyncGadgetDestroy, iGadgetId, cMsTimeout, fIgnoreErrors); + + def asyncGadgetConnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False): + """ + Initiates a gadget connect task. + + Returns True on success, False on failure (logged). + + The task returns True on success and False on failure. + """ + return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetConnect", self.taskGadgetConnect, \ + (iGadgetId, )); + + def syncGadgetConnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False): + """Synchronous version.""" + return self.asyncToSync(self.asyncGadgetConnect, iGadgetId, cMsTimeout, fIgnoreErrors); + + def asyncGadgetDisconnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False): + """ + Initiates a gadget disconnect task. + + Returns True on success, False on failure (logged). + + The task returns True on success and False on failure. + """ + return self.startTask(cMsTimeout, fIgnoreErrors, "GadgetDisconnect", self.taskGadgetDisconnect, \ + (iGadgetId, )); + + def syncGadgetDisconnect(self, iGadgetId, cMsTimeout = 30000, fIgnoreErrors = False): + """Synchronous version.""" + return self.asyncToSync(self.asyncGadgetDisconnect, iGadgetId, cMsTimeout, fIgnoreErrors); + + +class TransportTcp(TransportBase): + """ + TCP transport layer for the UTS client session class. + """ + + def __init__(self, sHostname, uPort): + """ + Save the parameters. The session will call us back to make the + connection later on its worker thread. + """ + TransportBase.__init__(self, utils.getCallerName()); + self.sHostname = sHostname; + self.uPort = uPort if uPort is not None else 6042; + self.oSocket = None; + self.oWakeupW = None; + self.oWakeupR = None; + self.fConnectCanceled = False; + self.fIsConnecting = False; + self.oCv = threading.Condition(); + self.abReadAhead = array.array('B'); + + def toString(self): + return '<%s sHostname=%s, uPort=%s, oSocket=%s,'\ + ' fConnectCanceled=%s, fIsConnecting=%s, oCv=%s, abReadAhead=%s>' \ + % (TransportBase.toString(self), self.sHostname, self.uPort, self.oSocket, + self.fConnectCanceled, self.fIsConnecting, self.oCv, self.abReadAhead); + + def __isInProgressXcpt(self, oXcpt): + """ In progress exception? """ + try: + if isinstance(oXcpt, socket.error): + try: + if oXcpt[0] == errno.EINPROGRESS: + return True; + except: pass; + try: + if oXcpt[0] == errno.EWOULDBLOCK: + return True; + if utils.getHostOs() == 'win' and oXcpt[0] == errno.WSAEWOULDBLOCK: # pylint: disable=no-member + return True; + except: pass; + except: + pass; + return False; + + def __isWouldBlockXcpt(self, oXcpt): + """ Would block exception? """ + try: + if isinstance(oXcpt, socket.error): + try: + if oXcpt[0] == errno.EWOULDBLOCK: + return True; + except: pass; + try: + if oXcpt[0] == errno.EAGAIN: + return True; + except: pass; + except: + pass; + return False; + + def __isConnectionReset(self, oXcpt): + """ Connection reset by Peer or others. """ + try: + if isinstance(oXcpt, socket.error): + try: + if oXcpt[0] == errno.ECONNRESET: + return True; + except: pass; + try: + if oXcpt[0] == errno.ENETRESET: + return True; + except: pass; + except: + pass; + return False; + + def _closeWakeupSockets(self): + """ Closes the wakup sockets. Caller should own the CV. """ + oWakeupR = self.oWakeupR; + self.oWakeupR = None; + if oWakeupR is not None: + oWakeupR.close(); + + oWakeupW = self.oWakeupW; + self.oWakeupW = None; + if oWakeupW is not None: + oWakeupW.close(); + + return None; + + def cancelConnect(self): + # This is bad stuff. + self.oCv.acquire(); + reporter.log2('TransportTcp::cancelConnect: fIsConnecting=%s oSocket=%s' % (self.fIsConnecting, self.oSocket)); + self.fConnectCanceled = True; + if self.fIsConnecting: + oSocket = self.oSocket; + self.oSocket = None; + if oSocket is not None: + reporter.log2('TransportTcp::cancelConnect: closing the socket'); + oSocket.close(); + + oWakeupW = self.oWakeupW; + self.oWakeupW = None; + if oWakeupW is not None: + reporter.log2('TransportTcp::cancelConnect: wakeup call'); + try: oWakeupW.send('cancelled!\n'); + except: reporter.logXcpt(); + try: oWakeupW.shutdown(socket.SHUT_WR); + except: reporter.logXcpt(); + oWakeupW.close(); + self.oCv.release(); + + def _connectAsClient(self, oSocket, oWakeupR, cMsTimeout): + """ Connects to the UTS server as client. """ + + # Connect w/ timeouts. + rc = None; + try: + oSocket.connect((self.sHostname, self.uPort)); + rc = True; + except socket.error as oXcpt: + iRc = oXcpt.errno; + if self.__isInProgressXcpt(oXcpt): + # Do the actual waiting. + reporter.log2('TransportTcp::connect: operation in progress (%s)...' % (oXcpt,)); + try: + ttRc = select.select([oWakeupR], [oSocket], [oSocket, oWakeupR], cMsTimeout / 1000.0); + if len(ttRc[1]) + len(ttRc[2]) == 0: + raise socket.error(errno.ETIMEDOUT, 'select timed out'); + iRc = oSocket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR); + rc = iRc == 0; + except socket.error as oXcpt2: + iRc = oXcpt2.errno; + except: + iRc = -42; + reporter.fatalXcpt('socket.select() on connect failed'); + + if rc is True: + pass; + elif iRc in (errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.EINTR, errno.ENETDOWN, errno.ENETUNREACH, errno.ETIMEDOUT): + rc = False; # try again. + else: + if iRc != errno.EBADF or not self.fConnectCanceled: + reporter.fatalXcpt('socket.connect((%s,%s)) failed; iRc=%s' % (self.sHostname, self.uPort, iRc)); + reporter.log2('TransportTcp::connect: rc=%s iRc=%s' % (rc, iRc)); + except: + reporter.fatalXcpt('socket.connect((%s,%s)) failed' % (self.sHostname, self.uPort)); + return rc; + + + def connect(self, cMsTimeout): + # Create a non-blocking socket. + reporter.log2('TransportTcp::connect: cMsTimeout=%s sHostname=%s uPort=%s' % (cMsTimeout, self.sHostname, self.uPort)); + try: + oSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0); + except: + reporter.fatalXcpt('socket.socket() failed'); + return None; + try: + oSocket.setblocking(0); + except: + oSocket.close(); + reporter.fatalXcpt('socket.socket() failed'); + return None; + + # Create wakeup socket pair for unix (select doesn't wake up on socket close on Linux). + oWakeupR = None; + oWakeupW = None; + if hasattr(socket, 'socketpair'): + try: (oWakeupR, oWakeupW) = socket.socketpair(); # pylint: disable=no-member + except: reporter.logXcpt('socket.socketpair() failed'); + + # Update the state. + self.oCv.acquire(); + rc = None; + if not self.fConnectCanceled: + self.oSocket = oSocket; + self.oWakeupW = oWakeupW; + self.oWakeupR = oWakeupR; + self.fIsConnecting = True; + self.oCv.release(); + + # Try connect. + if oWakeupR is None: + oWakeupR = oSocket; # Avoid select failure. + rc = self._connectAsClient(oSocket, oWakeupR, cMsTimeout); + oSocket = None; + + # Update the state and cleanup on failure/cancel. + self.oCv.acquire(); + if rc is True and self.fConnectCanceled: + rc = False; + self.fIsConnecting = False; + + if rc is not True: + if self.oSocket is not None: + self.oSocket.close(); + self.oSocket = None; + self._closeWakeupSockets(); + self.oCv.release(); + + reporter.log2('TransportTcp::connect: returning %s' % (rc,)); + return rc; + + def disconnect(self, fQuiet = False): + if self.oSocket is not None: + self.abReadAhead = array.array('B'); + + # Try a shutting down the socket gracefully (draining it). + try: + self.oSocket.shutdown(socket.SHUT_WR); + except: + if not fQuiet: + reporter.error('shutdown(SHUT_WR)'); + try: + self.oSocket.setblocking(0); # just in case it's not set. + sData = "1"; + while sData: + sData = self.oSocket.recv(16384); + except: + pass; + + # Close it. + self.oCv.acquire(); + try: self.oSocket.setblocking(1); + except: pass; + self.oSocket.close(); + self.oSocket = None; + else: + self.oCv.acquire(); + self._closeWakeupSockets(); + self.oCv.release(); + + def sendBytes(self, abBuf, cMsTimeout): + if self.oSocket is None: + reporter.error('TransportTcp.sendBytes: No connection.'); + return False; + + # Try send it all. + try: + cbSent = self.oSocket.send(abBuf); + if cbSent == len(abBuf): + return True; + except Exception as oXcpt: + if not self.__isWouldBlockXcpt(oXcpt): + reporter.errorXcpt('TranportTcp.sendBytes: %s bytes' % (len(abBuf))); + return False; + cbSent = 0; + + # Do a timed send. + msStart = base.timestampMilli(); + while True: + cMsElapsed = base.timestampMilli() - msStart; + if cMsElapsed > cMsTimeout: + reporter.error('TranportTcp.sendBytes: %s bytes timed out (1)' % (len(abBuf))); + break; + + # wait. + try: + ttRc = select.select([], [self.oSocket], [self.oSocket], (cMsTimeout - cMsElapsed) / 1000.0); + if ttRc[2] and not ttRc[1]: + reporter.error('TranportTcp.sendBytes: select returned with exception'); + break; + if not ttRc[1]: + reporter.error('TranportTcp.sendBytes: %s bytes timed out (2)' % (len(abBuf))); + break; + except: + reporter.errorXcpt('TranportTcp.sendBytes: select failed'); + break; + + # Try send more. + try: + cbSent += self.oSocket.send(abBuf[cbSent:]); + if cbSent == len(abBuf): + return True; + except Exception as oXcpt: + if not self.__isWouldBlockXcpt(oXcpt): + reporter.errorXcpt('TranportTcp.sendBytes: %s bytes' % (len(abBuf))); + break; + + return False; + + def __returnReadAheadBytes(self, cb): + """ Internal worker for recvBytes. """ + assert(len(self.abReadAhead) >= cb); + abRet = self.abReadAhead[:cb]; + self.abReadAhead = self.abReadAhead[cb:]; + return abRet; + + def recvBytes(self, cb, cMsTimeout, fNoDataOk): + if self.oSocket is None: + reporter.error('TransportTcp.recvBytes(%s,%s): No connection.' % (cb, cMsTimeout)); + return None; + + # Try read in some more data without bothering with timeout handling first. + if len(self.abReadAhead) < cb: + try: + abBuf = self.oSocket.recv(cb - len(self.abReadAhead)); + if abBuf: + self.abReadAhead.extend(array.array('B', abBuf)); + except Exception as oXcpt: + if not self.__isWouldBlockXcpt(oXcpt): + reporter.errorXcpt('TranportTcp.recvBytes: 0/%s bytes' % (cb,)); + return None; + + if len(self.abReadAhead) >= cb: + return self.__returnReadAheadBytes(cb); + + # Timeout loop. + msStart = base.timestampMilli(); + while True: + cMsElapsed = base.timestampMilli() - msStart; + if cMsElapsed > cMsTimeout: + if not fNoDataOk or self.abReadAhead: + reporter.error('TranportTcp.recvBytes: %s/%s bytes timed out (1)' % (len(self.abReadAhead), cb)); + break; + + # Wait. + try: + ttRc = select.select([self.oSocket], [], [self.oSocket], (cMsTimeout - cMsElapsed) / 1000.0); + if ttRc[2] and not ttRc[0]: + reporter.error('TranportTcp.recvBytes: select returned with exception'); + break; + if not ttRc[0]: + if not fNoDataOk or self.abReadAhead: + reporter.error('TranportTcp.recvBytes: %s/%s bytes timed out (2) fNoDataOk=%s' + % (len(self.abReadAhead), cb, fNoDataOk)); + break; + except: + reporter.errorXcpt('TranportTcp.recvBytes: select failed'); + break; + + # Try read more. + try: + abBuf = self.oSocket.recv(cb - len(self.abReadAhead)); + if not abBuf: + reporter.error('TranportTcp.recvBytes: %s/%s bytes (%s) - connection has been shut down' + % (len(self.abReadAhead), cb, fNoDataOk)); + self.disconnect(); + return None; + + self.abReadAhead.extend(array.array('B', abBuf)); + + except Exception as oXcpt: + reporter.log('recv => exception %s' % (oXcpt,)); + if not self.__isWouldBlockXcpt(oXcpt): + if not fNoDataOk or not self.__isConnectionReset(oXcpt) or self.abReadAhead: + reporter.errorXcpt('TranportTcp.recvBytes: %s/%s bytes (%s)' % (len(self.abReadAhead), cb, fNoDataOk)); + break; + + # Done? + if len(self.abReadAhead) >= cb: + return self.__returnReadAheadBytes(cb); + + #reporter.log('recv => None len(self.abReadAhead) -> %d' % (len(self.abReadAhead), )); + return None; + + def isConnectionOk(self): + if self.oSocket is None: + return False; + try: + ttRc = select.select([], [], [self.oSocket], 0.0); + if ttRc[2]: + return False; + + self.oSocket.send(array.array('B')); # send zero bytes. + except: + return False; + return True; + + def isRecvPending(self, cMsTimeout = 0): + try: + ttRc = select.select([self.oSocket], [], [], cMsTimeout / 1000.0); + if not ttRc[0]: + return False; + except: + pass; + return True; + + +class UsbGadget(object): + """ + USB Gadget control class using the USBT Test Service to talk to the external + board behaving like a USB device. + """ + + def __init__(self): + self.oUtsSession = None; + self.sImpersonation = g_ksGadgetImpersonationInvalid; + self.idGadget = None; + self.iBusId = None; + self.iDevId = None; + self.iUsbIpPort = None; + + def clearImpersonation(self): + """ + Removes the current impersonation of the gadget. + """ + fRc = True; + + if self.idGadget is not None: + fRc = self.oUtsSession.syncGadgetDestroy(self.idGadget); + self.idGadget = None; + self.iBusId = None; + self.iDevId = None; + + return fRc; + + def disconnectUsb(self): + """ + Disconnects the USB gadget from the host. (USB connection not network + connection used for control) + """ + return self.oUtsSession.syncGadgetDisconnect(self.idGadget); + + def connectUsb(self): + """ + Connect the USB gadget to the host. + """ + return self.oUtsSession.syncGadgetConnect(self.idGadget); + + def impersonate(self, sImpersonation, fSuperSpeed = False): + """ + Impersonate a given device. + """ + + # Clear any previous impersonation + self.clearImpersonation(); + self.sImpersonation = sImpersonation; + + fRc = False; + if sImpersonation == g_ksGadgetImpersonationTest: + lstCfg = []; + if fSuperSpeed is True: + lstCfg.append( ('Gadget/SuperSpeed', g_kiGadgetCfgTypeBool, 'true') ); + fDone = self.oUtsSession.syncGadgetCreate(g_kiGadgetTypeTest, g_kiGadgetAccessUsbIp, lstCfg); + if fDone is True and self.oUtsSession.isSuccess(): + # Get the gadget ID. + _, _, abPayload = self.oUtsSession.getLastReply(); + + fRc = True; + self.idGadget = getU32(abPayload, 16); + self.iBusId = getU32(abPayload, 20); + self.iDevId = getU32(abPayload, 24); + else: + reporter.log('Invalid or unsupported impersonation'); + + return fRc; + + def getUsbIpPort(self): + """ + Returns the port the USB/IP server is listening on if requested, + None if USB/IP is not supported. + """ + return self.iUsbIpPort; + + def getGadgetBusAndDevId(self): + """ + Returns the bus ad device ID of the gadget as a tuple. + """ + return (self.iBusId, self.iDevId); + + def connectTo(self, cMsTimeout, sHostname, uPort = None, fUsbIpSupport = True, cMsIdleFudge = 0, fTryConnect = False): + """ + Connects to the specified target device. + Returns True on Success. + Returns False otherwise. + """ + fRc = True; + + # @todo + if fUsbIpSupport is False: + return False; + + reporter.log2('openTcpSession(%s, %s, %s, %s)' % \ + (cMsTimeout, sHostname, uPort, cMsIdleFudge)); + try: + oTransport = TransportTcp(sHostname, uPort); + self.oUtsSession = Session(oTransport, cMsTimeout, cMsIdleFudge, fTryConnect); + + if self.oUtsSession is not None: + fDone = self.oUtsSession.waitForTask(30*1000); + reporter.log('connect: waitForTask -> %s, result %s' % (fDone, self.oUtsSession.getResult())); + if fDone is True and self.oUtsSession.isSuccess(): + # Parse the reply. + _, _, abPayload = self.oUtsSession.getLastReply(); + + if getU32(abPayload, 20) is g_kiGadgetAccessUsbIp: + fRc = True; + self.iUsbIpPort = getU32(abPayload, 24); + else: + reporter.log('Gadget doesn\'t support access over USB/IP despite being requested'); + fRc = False; + else: + fRc = False; + else: + fRc = False; + except: + reporter.errorXcpt(None, 15); + return False; + + return fRc; + + def disconnectFrom(self): + """ + Disconnects from the target device. + """ + fRc = True; + + self.clearImpersonation(); + if self.oUtsSession is not None: + fRc = self.oUtsSession.syncDisconnect(); + + return fRc; |