diff --git a/portal/client/nimbus_portal_client.nim b/portal/client/nimbus_portal_client.nim index f19084d05b..898026db9c 100644 --- a/portal/client/nimbus_portal_client.nim +++ b/portal/client/nimbus_portal_client.nim @@ -155,7 +155,7 @@ proc run(portalClient: PortalClient, config: PortalConf) {.raises: [CatchableErr # measure to easily identify & debug the clients used in the testnet. # Might make this into a, default off, cli option. localEnrFields = - {"c": enrClientInfoShort, portalVersionKey: SSZ.encode(localSupportedVersions)}, + {"c": enrClientInfoShort, portalEnrKey: rlp.encode(localPortalEnrField)}, bootstrapRecords = bootstrapRecords, previousRecord = previousEnr, bindIp = bindIp, diff --git a/portal/network/wire/portal_protocol.nim b/portal/network/wire/portal_protocol.nim index 8398ebe7e0..6787277503 100644 --- a/portal/network/wire/portal_protocol.nim +++ b/portal/network/wire/portal_protocol.nim @@ -335,7 +335,7 @@ func fromNodeStatus(T: type NodeAddResult, status: NodeStatus): T = of NodeStatus.Banned: T.Banned proc addNode*(p: PortalProtocol, node: Node): NodeAddResult = - if node.highestCommonPortalVersion(localSupportedVersions).isOk(): + if node.highestCommonPortalVersionAndChain(localPortalEnrField).isOk(): let status = p.routingTable.addNode(node) trace "Adding node to routing table", status, node NodeAddResult.fromNodeStatus(status) @@ -672,8 +672,8 @@ proc messageHandler( warn "No ENR found for node", srcId, srcUdpAddress return @[] - let _ = enr.highestCommonPortalVersion(localSupportedVersions).valueOr: - debug "No compatible protocol version found", error, srcId, srcUdpAddress + let _ = enr.highestCommonPortalVersionAndChain(localPortalEnrField).valueOr: + debug "Incompatible protocols", error, srcId, srcUdpAddress return @[] let decoded = decodeMessage(request) @@ -884,7 +884,7 @@ proc ping*( async: (raises: [CancelledError]) .} = # Fail if no common portal version is found - let _ = ?dst.highestCommonPortalVersion(localSupportedVersions) + let _ = ?dst.highestCommonPortalVersionAndChain(localPortalEnrField) if p.isBanned(dst.id): return err("destination node is banned") @@ -910,7 +910,7 @@ proc findNodes*( p: PortalProtocol, dst: Node, distances: seq[uint16] ): Future[PortalResult[seq[Node]]] {.async: (raises: [CancelledError]).} = # Fail if no common portal version is found - let _ = ?dst.highestCommonPortalVersion(localSupportedVersions) + let _ = ?dst.highestCommonPortalVersionAndChain(localPortalEnrField) if p.isBanned(dst.id): return err("destination node is banned") @@ -928,7 +928,7 @@ proc findContent*( p: PortalProtocol, dst: Node, contentKey: ContentKeyByteList ): Future[PortalResult[FoundContent]] {.async: (raises: [CancelledError]).} = # Fail if no common portal version is found - let _ = ?dst.highestCommonPortalVersion(localSupportedVersions) + let _ = ?dst.highestCommonPortalVersionAndChain(localPortalEnrField) logScope: node = dst @@ -1053,7 +1053,7 @@ proc offer( ## guarantee content transfer. # Fail if no common portal version is found - let _ = ?o.dst.highestCommonPortalVersion(localSupportedVersions) + let _ = ?o.dst.highestCommonPortalVersionAndChain(localPortalEnrField) let contentKeys = getContentKeys(o) diff --git a/portal/network/wire/portal_protocol_version.nim b/portal/network/wire/portal_protocol_version.nim index d898751717..89abc18e72 100644 --- a/portal/network/wire/portal_protocol_version.nim +++ b/portal/network/wire/portal_protocol_version.nim @@ -8,41 +8,69 @@ {.push raises: [].} import - std/sequtils, - ssz_serialization, + eth/rlp, eth/p2p/discoveryv5/[enr, node], + eth/common/base_rlp, ../../common/common_types -export ssz_serialization +export base_rlp -type PortalVersionValue* = List[uint8, 8] +type PortalEnrField* = object + pvMin: uint8 + pvMax: uint8 + chainId: ChainId const - portalVersionKey* = "pv" - localSupportedVersions* = PortalVersionValue(@[1'u8]) + portalEnrKey* = "p" + localSupportedVersionMin* = 2'u8 + localSupportedVersionMax* = 2'u8 + localChainId* = 1.chainId() # Mainnet by default, TODO: runtime configuration + localPortalEnrField* = PortalEnrField( + pvMin: localSupportedVersionMin, + pvMax: localSupportedVersionMax, + chainId: localChainId, + ) -func getPortalVersions(record: Record): Result[PortalVersionValue, string] = - let valueBytes = record.get(portalVersionKey, seq[byte]).valueOr: - return ok(PortalVersionValue(@[0'u8])) +func init*(T: type PortalEnrField, pvMin: uint8, pvMax: uint8, chainId: ChainId): T = + T(pvMin: pvMin, pvMax: pvMax, chainId: chainId) - decodeSsz(valueBytes, PortalVersionValue) +func getPortalEnrField(record: Record): Result[PortalEnrField, string] = + let valueBytes = record.get(portalEnrKey, seq[byte]).valueOr: + # When no field, default to version 0 and mainnet chainId + return ok(PortalEnrField(pvMin: 0'u8, pvMax: 0'u8, chainId: 1.chainId())) -func highestCommonPortalVersion( - versions: PortalVersionValue, supportedVersions: PortalVersionValue + let portalField = decodeRlp(valueBytes, PortalEnrField).valueOr: + return err("Failed to decode Portal field: " & error) + + if portalField.pvMin > portalField.pvMax: + return err("Invalid Portal ENR field: minimum version > maximum version") + + ok(portalField) + +func highestCommonPortalVersionAndChain( + a: PortalEnrField, b: PortalEnrField ): Result[uint8, string] = - let commonVersions = versions.filterIt(supportedVersions.contains(it)) - if commonVersions.len == 0: - return err("No common protocol versions found") + if a.chainId != b.chainId: + return err("ChainId mismatch: remote=" & $a.chainId & ", local=" & $b.chainId) + + let + commonMin = max(a.pvMin, b.pvMin) + commonMax = min(a.pvMax, b.pvMax) + + if commonMin > commonMax: + return err("No common Portal wire protocol version found") - ok(max(commonVersions)) + ok(commonMax) -func highestCommonPortalVersion*( - record: Record, supportedVersions: PortalVersionValue +func highestCommonPortalVersionAndChain*( + record: Record, supportedPortalField: PortalEnrField ): Result[uint8, string] = - let versions = ?record.getPortalVersions() - versions.highestCommonPortalVersion(supportedVersions) + ## Return highest common portal protocol version of both ENRs, but only if chainIds match + let portalField = ?record.getPortalEnrField() + portalField.highestCommonPortalVersionAndChain(supportedPortalField) -func highestCommonPortalVersion*( - node: Node, supportedVersions: PortalVersionValue +func highestCommonPortalVersionAndChain*( + node: Node, supportedPortalField: PortalEnrField ): Result[uint8, string] = - node.record.highestCommonPortalVersion(supportedVersions) + ## Return highest common portal protocol version of both nodes, but only if chainIds match + node.record.highestCommonPortalVersionAndChain(supportedPortalField) diff --git a/portal/tests/test_helpers.nim b/portal/tests/test_helpers.nim index 141261b5b1..6a21cd1b4e 100644 --- a/portal/tests/test_helpers.nim +++ b/portal/tests/test_helpers.nim @@ -32,7 +32,7 @@ proc initDiscoveryNode*( enrFields.add(localEnrFields) # Always inject the portal wire version field into the ENR # When no field, it would mean v0 only support - enrFields.add((portalVersionKey, SSZ.encode(localSupportedVersions))) + enrFields.add((portalEnrKey, rlp.encode(localPortalEnrField))) result = newProtocol( privKey, diff --git a/portal/tests/wire_protocol_tests/test_portal_wire_version.nim b/portal/tests/wire_protocol_tests/test_portal_wire_version.nim index f72862e3a4..913520d9cc 100644 --- a/portal/tests/wire_protocol_tests/test_portal_wire_version.nim +++ b/portal/tests/wire_protocol_tests/test_portal_wire_version.nim @@ -25,65 +25,77 @@ suite "Portal Wire Protocol Version": test "ENR with no Portal version field": let - localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8]) + localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId()) enr = Record.init(1, pk, ip, port, port, []).expect("Valid ENR init") - let version = enr.highestCommonPortalVersion(localSupportedVersions) + let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField) check: version.isOk() version.get() == 0'u8 - test "ENR with empty Portal version list": + test "ENR with empty Portal ENR field list": let - localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8]) - portalVersions = PortalVersionValue(@[]) - customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))] + localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId()) + portalEnrField = @[byte 0xc0] # Empty rlp list + customEnrFields = [toFieldPair(portalEnrKey, portalEnrField)] enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init") - let version = enr.highestCommonPortalVersion(localSupportedVersions) + let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField) check version.isErr() test "ENR with unsupported Portal versions": let - localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8]) - portalVersions = PortalVersionValue(@[255'u8, 100'u8, 2'u8]) - customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))] + localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId()) + portalEnrField = PortalEnrField.init(2'u8, 255'u8, 2.chainId()) + + customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))] enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init") - let version = enr.highestCommonPortalVersion(localSupportedVersions) + let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField) check version.isErr() test "ENR with supported Portal version": let - localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8]) - portalVersions = PortalVersionValue(@[3'u8, 2'u8, 1'u8]) - customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))] + localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId()) + portalEnrField = PortalEnrField.init(1'u8, 3'u8, 1.chainId()) + + customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))] enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init") - let version = enr.highestCommonPortalVersion(localSupportedVersions) + let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField) check: version.isOk() version.get() == 1'u8 test "ENR with multiple supported Portal versions": let - localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8, 2'u8]) - portalVersions = PortalVersionValue(@[0'u8, 2'u8, 2'u8, 3'u8]) - customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))] + localPortalEnrField = PortalEnrField.init(0'u8, 2'u8, 1.chainId()) + portalEnrField = PortalEnrField.init(0'u8, 3'u8, 1.chainId()) + customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))] enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init") - let version = enr.highestCommonPortalVersion(localSupportedVersions) + let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField) check: version.isOk() version.get() == 2'u8 - test "ENR with too many Portal versions": + test "ENR with invalid Portal version range (min > max)": + let + localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId()) + portalEnrField = PortalEnrField.init(2'u8, 1'u8, 1.chainId()) + customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))] + enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init") + + let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField) + check version.isErr() + + test "ENR with supported Portal version but different chain id": let - localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8, 2'u8]) - portalVersions = - PortalVersionValue(@[0'u8, 1'u8, 2'u8, 3'u8, 4'u8, 5'u8, 6'u8, 7'u8, 8'u8]) - customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))] + localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId()) + portalEnrField = PortalEnrField.init(1'u8, 3'u8, 2.chainId()) + + customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))] enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init") - let version = enr.highestCommonPortalVersion(localSupportedVersions) + let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField) check version.isErr()