Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ClearFeature(ENDPOINT_HALT) #266

Merged
merged 1 commit into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions applets/clear_endpoint_halt_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
#!/usr/bin/env python3
#
# This file is part of LUNA.
#
# Copyright (c) 2024 Great Scott Gadgets <[email protected]>
# SPDX-License-Identifier: BSD-3-Clause

import logging
import os
import time
import usb1

from amaranth import Elaboratable, Module, Signal

from luna import top_level_cli, configure_default_logging
from luna.usb2 import USBDevice, USBStreamInEndpoint, USBStreamOutEndpoint
from luna.gateware.stream.generator import StreamSerializer
from luna.gateware.usb.request.control import ControlRequestHandler
from luna.gateware.usb.stream import USBInStreamInterface

from usb_protocol.types import USBRequestRecipient, USBRequestType
from usb_protocol.emitters import DeviceDescriptorCollection

# use pid.codes Test PID
VID = 0x1209
PID = 0x0001

BULK_ENDPOINT_NUMBER = 1
MAX_BULK_PACKET_SIZE = 512

COUNTER_MAX = 251
GET_OUT_COUNTER_VALID = 0

out_counter_valid = Signal(reset=1)

class VendorRequestHandler(ControlRequestHandler):

REQUEST_SET_LEDS = 0

def elaborate(self, platform):
m = Module()

interface = self.interface
setup = self.interface.setup

# Transmitter for small-constant-response requests
m.submodules.transmitter = transmitter = \
StreamSerializer(data_length=1, domain="usb", stream_type=USBInStreamInterface, max_length_width=1)
#
# Vendor request handlers.
with m.FSM(domain="usb"):
with m.State('IDLE'):
vendor = setup.type == USBRequestType.VENDOR
with m.If(
setup.received & \
(setup.type == USBRequestType.VENDOR) & \
(setup.recipient == USBRequestRecipient.INTERFACE) & \
(setup.index == 0)):
with m.Switch(setup.request):
with m.Case(GET_OUT_COUNTER_VALID):
m.d.comb += interface.claim.eq(1)
m.next = 'GET_OUT_COUNTER_VALID'
pass

with m.State('GET_OUT_COUNTER_VALID'):
m.d.comb += interface.claim.eq(1)
self.handle_simple_data_request(m, transmitter, out_counter_valid, length=1)

return m


class ClearHaltTestDevice(Elaboratable):


def create_descriptors(self):

descriptors = DeviceDescriptorCollection()

with descriptors.DeviceDescriptor() as d:
d.idVendor = VID
d.idProduct = PID

d.iManufacturer = "LUNA"
d.iProduct = "Clear Endpoint Halt Test"
d.iSerialNumber = "no serial"

d.bNumConfigurations = 1


with descriptors.ConfigurationDescriptor() as c:

with c.InterfaceDescriptor() as i:
i.bInterfaceNumber = 0

with i.EndpointDescriptor() as e:
e.bEndpointAddress = 0x80 | BULK_ENDPOINT_NUMBER
e.wMaxPacketSize = MAX_BULK_PACKET_SIZE

with i.EndpointDescriptor() as e:
e.bEndpointAddress = BULK_ENDPOINT_NUMBER
e.wMaxPacketSize = MAX_BULK_PACKET_SIZE


return descriptors


def elaborate(self, platform):
m = Module()

m.submodules.car = platform.clock_domain_generator()

ulpi = platform.request(platform.default_usb_connection)
m.submodules.usb = usb = USBDevice(bus=ulpi)

descriptors = self.create_descriptors()
control_ep = usb.add_standard_control_endpoint(descriptors)

control_ep.add_request_handler(VendorRequestHandler())

stream_in_ep = USBStreamInEndpoint(
endpoint_number=BULK_ENDPOINT_NUMBER,
max_packet_size=MAX_BULK_PACKET_SIZE
)
usb.add_endpoint(stream_in_ep)

stream_out_ep = USBStreamOutEndpoint(
endpoint_number=BULK_ENDPOINT_NUMBER,
max_packet_size=MAX_BULK_PACKET_SIZE
)
usb.add_endpoint(stream_out_ep)

# Generate a counter on the IN endpoint.
in_counter = Signal(8)
with m.If(stream_in_ep.stream.ready):
m.d.usb += in_counter.eq(in_counter + 1)
with m.If(in_counter == COUNTER_MAX):
m.d.usb += in_counter.eq(0)

# Expect a counter on the OUT endpoint, and verify that it is contiguous.
prev_out_counter = Signal(8, reset=COUNTER_MAX)
with m.If(stream_out_ep.stream.valid):
out_counter = stream_out_ep.stream.payload
counter_increase = out_counter == (prev_out_counter + 1)
counter_wrap = (out_counter == 0) & (prev_out_counter == COUNTER_MAX)
with m.If(~counter_increase & ~counter_wrap):
m.d.usb += out_counter_valid.eq(0)

m.d.usb += prev_out_counter.eq(out_counter)

m.d.comb += [
stream_in_ep.stream.valid .eq(1),
stream_in_ep.stream.payload .eq(in_counter),

stream_out_ep.stream.ready .eq(1),
]

# Connect our device as a high speed device by default.
m.d.comb += [
usb.connect .eq(1),
usb.full_speed_only .eq(1 if os.getenv('LUNA_FULL_ONLY') else 0),
]

return m

def test_clear_halt():
with usb1.USBContext() as context:
device = context.openByVendorIDAndProductID(VID, PID)

# Read the first packet which should have a DATA0 PID, next we expect DATA1.
packet = device.bulkRead(BULK_ENDPOINT_NUMBER, MAX_BULK_PACKET_SIZE)
# Send clear halt, this resets both sides to DATA0.
device.clearHalt(usb1.ENDPOINT_IN | BULK_ENDPOINT_NUMBER)
# Read another packet. If the PID doesn't match what we epxect,
# then the host will assume it was a retransmission of the last one and drop it.
packet += device.bulkRead(BULK_ENDPOINT_NUMBER, MAX_BULK_PACKET_SIZE)

# Check that the counter is contiguous across all received data, making sure we didn't drop a packet.
for i in range(1, len(packet)):
if packet[i] == packet[i-1] + 1:
pass
elif packet[i] == 0 and packet[i-1] == COUNTER_MAX:
pass
else:
print(f"IN test fail {i} {packet[i]} {packet[i-1]}")
return

print("IN OK")

# Generate three packets worth of counter data, the gateware will verify that it is contiguous.
data = bytes(i % (COUNTER_MAX+1) for i in range(MAX_BULK_PACKET_SIZE*3))
# Send DATA0, device should expect DATA1 next.
device.bulkWrite(BULK_ENDPOINT_NUMBER, data[:MAX_BULK_PACKET_SIZE])
# Reset both sides to DATA0.
device.clearHalt(usb1.ENDPOINT_OUT | BULK_ENDPOINT_NUMBER)
# Send two packets. If the first packet doesn't match,
# it'll be dropped and another is required to let the gateware check the counter.
device.bulkWrite(BULK_ENDPOINT_NUMBER, data[MAX_BULK_PACKET_SIZE:])

# Read back the out_counter_valid register to check for success.
request_type = usb1.REQUEST_TYPE_VENDOR | usb1.RECIPIENT_INTERFACE | usb1.ENDPOINT_IN
if device.controlRead(request_type, GET_OUT_COUNTER_VALID, 0, 0, 1)[0] == 1:
print("OUT OK")
else:
print("OUT FAIL")


if __name__ == "__main__":
configure_default_logging()

# If our environment is suggesting we rerun tests without rebuilding, do so.
if os.getenv('LUNA_RERUN_TEST'):
logging.info("Running speed test without rebuilding...")

# Otherwise, rebuild.
else:
device = top_level_cli(ClearHaltTestDevice)

# Give the device a moment to connect.
if device is not None:
logging.info("Giving the device time to connect...")
time.sleep(5)

test_clear_halt()
28 changes: 27 additions & 1 deletion luna/gateware/usb/request/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from amaranth import *
from amaranth.hdl.ast import Value, Const
from usb_protocol.types import USBStandardRequests, USBRequestType
from usb_protocol.types import USBStandardFeatures, USBStandardRequests, USBRequestRecipient, USBRequestType
from usb_protocol.emitters import DeviceDescriptorCollection

from ..usb2.request import RequestHandlerInterface, USBRequestHandler
Expand Down Expand Up @@ -139,6 +139,8 @@ def elaborate(self, platform):

with m.Case(USBStandardRequests.GET_STATUS):
m.next = 'GET_STATUS'
with m.Case(USBStandardRequests.CLEAR_FEATURE):
m.next = 'CLEAR_FEATURE'
with m.Case(USBStandardRequests.SET_ADDRESS):
m.next = 'SET_ADDRESS'
with m.Case(USBStandardRequests.SET_CONFIGURATION):
Expand All @@ -158,6 +160,30 @@ def elaborate(self, platform):
# TODO: copy the remote wakeup and bus-powered attributes from bmAttributes of the relevant descriptor?
self.handle_simple_data_request(m, transmitter, 0, length=2)

with m.State('CLEAR_FEATURE'):
# Provide an response to the STATUS stage.
with m.If(interface.status_requested):

# If our stall condition is met, stall; otherwise, send a ZLP [USB 8.5.3].
# For now, we only implement clearing ENDPOINT_HALT.
stall_condition = \
(setup.recipient != USBRequestRecipient.ENDPOINT) | \
(setup.value != USBStandardFeatures.ENDPOINT_HALT)
with m.If(stall_condition):
m.d.comb += handshake_generator.stall.eq(1)
with m.Else():
m.d.comb += self.send_zlp()

# Accept the relevant value after the packet is ACK'd...
with m.If(interface.handshakes_in.ack):
m.d.comb += [
interface.clear_endpoint_halt.enable .eq(1),
interface.clear_endpoint_halt.direction.eq(setup.index[7]),
interface.clear_endpoint_halt.number .eq(setup.index[0:4]),
]

# ... and then return to idle.
m.next = 'IDLE'
miek marked this conversation as resolved.
Show resolved Hide resolved

# SET_ADDRESS -- The host is trying to assign us an address.
with m.State('SET_ADDRESS'):
Expand Down
2 changes: 2 additions & 0 deletions luna/gateware/usb/usb2/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def elaborate(self, platform):
interface.address_changed .eq(request_handler.address_changed),
interface.new_address .eq(request_handler.new_address),

interface.clear_endpoint_halt_out .eq(request_handler.clear_endpoint_halt),

request_handler.active_config .eq(interface.active_config),
interface.config_changed .eq(request_handler.config_changed),
interface.new_config .eq(request_handler.new_config),
Expand Down
10 changes: 10 additions & 0 deletions luna/gateware/usb/usb2/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .packet import DataCRCInterface, InterpacketTimerInterface, TokenDetectorInterface
from .packet import HandshakeExchangeInterface
from .request import ClearEndpointHaltInterface
from ..stream import USBInStreamInterface, USBOutStreamInterface
from ...utils.bus import OneHotMultiplexer

Expand Down Expand Up @@ -90,6 +91,9 @@ def __init__(self):
self.config_changed = Signal()
self.new_config = Signal(8)

self.clear_endpoint_halt_out = Signal(ClearEndpointHaltInterface)
self.clear_endpoint_halt_in = Signal(ClearEndpointHaltInterface)

self.rx = USBOutStreamInterface()
self.rx_complete = Signal()
self.rx_ready_for_response = Signal()
Expand Down Expand Up @@ -213,6 +217,8 @@ def elaborate(self, platform):
shared.handshakes_in .connect(interface.handshakes_in),
shared.tokenizer .connect(interface.tokenizer),

interface.clear_endpoint_halt_in .eq(shared.clear_endpoint_halt_out),

# Rx interface.
shared.rx .connect(interface.rx),
interface.rx_complete .eq(shared.rx_complete),
Expand Down Expand Up @@ -259,6 +265,10 @@ def elaborate(self, platform):
# ... and our timer start signals.
self.or_join_interface_signals(m, lambda interface : interface.timer.start)

self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.enable)
self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.direction)
self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.number)

# Finally, connect up our transmit PID select.
conditional = m.If

Expand Down
18 changes: 18 additions & 0 deletions luna/gateware/usb/usb2/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def elaborate(self, platform):
# Create our transfer manager, which will be used to sequence packet transfers for our stream.
m.submodules.tx_manager = tx_manager = USBInTransferManager(self._max_packet_size)

# Check there has been a ClearFeature(ENDPOINT_HALT) request address to this endpoint.
clear_endpoint_halt = \
interface.clear_endpoint_halt_in.enable & \
interface.clear_endpoint_halt_in.direction & \
(interface.clear_endpoint_halt_in.number == self._endpoint_number)
m.d.comb += [

# Always generate ZLPs; in order to pass along when stream packets terminate.
Expand All @@ -94,6 +99,9 @@ def elaborate(self, platform):
tx_manager.flush .eq(self.flush),
tx_manager.discard .eq(self.discard),

# ... and data-toggle reset on clear endpoint halt...
tx_manager.reset_sequence .eq(clear_endpoint_halt),

# ... and our output stream...
interface.tx .stream_eq(tx_manager.packet_stream),
interface.tx_pid_toggle .eq(tx_manager.data_pid),
Expand Down Expand Up @@ -414,6 +422,16 @@ def elaborate(self, platform):
with m.If(data_response_requested & data_accepted):
m.d.usb += expected_data_toggle.eq(~expected_data_toggle)

# If there has been a ClearFeature(ENDPOINT_HALT) request address to this endpoint...
clear_endpoint_halt = \
self.interface.clear_endpoint_halt_in.enable & \
~self.interface.clear_endpoint_halt_in.direction & \
(self.interface.clear_endpoint_halt_in.number == self._endpoint_number)

with m.If(clear_endpoint_halt):
# ... reset the expected data toggle.
m.d.usb += expected_data_toggle.eq(0)


return m

Loading