Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion portal/client/nimbus_portal_client.nim
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ proc run(portalClient: PortalClient, config: PortalConf) {.raises: [CatchableErr
# measure to easily identify & debug the clients used in the testnet.
# Might make this into a, default off, cli option.
localEnrFields =
{"c": enrClientInfoShort, portalVersionKey: SSZ.encode(localSupportedVersions)},
{"c": enrClientInfoShort, portalEnrKey: rlp.encode(localPortalEnrField)},
bootstrapRecords = bootstrapRecords,
previousRecord = previousEnr,
bindIp = bindIp,
Expand Down
14 changes: 7 additions & 7 deletions portal/network/wire/portal_protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func fromNodeStatus(T: type NodeAddResult, status: NodeStatus): T =
of NodeStatus.Banned: T.Banned

proc addNode*(p: PortalProtocol, node: Node): NodeAddResult =
if node.highestCommonPortalVersion(localSupportedVersions).isOk():
if node.highestCommonPortalVersionAndChain(localPortalEnrField).isOk():
let status = p.routingTable.addNode(node)
trace "Adding node to routing table", status, node
NodeAddResult.fromNodeStatus(status)
Expand Down Expand Up @@ -672,8 +672,8 @@ proc messageHandler(
warn "No ENR found for node", srcId, srcUdpAddress
return @[]

let _ = enr.highestCommonPortalVersion(localSupportedVersions).valueOr:
debug "No compatible protocol version found", error, srcId, srcUdpAddress
let _ = enr.highestCommonPortalVersionAndChain(localPortalEnrField).valueOr:
debug "Incompatible protocols", error, srcId, srcUdpAddress
return @[]

let decoded = decodeMessage(request)
Expand Down Expand Up @@ -884,7 +884,7 @@ proc ping*(
async: (raises: [CancelledError])
.} =
# Fail if no common portal version is found
let _ = ?dst.highestCommonPortalVersion(localSupportedVersions)
let _ = ?dst.highestCommonPortalVersionAndChain(localPortalEnrField)

if p.isBanned(dst.id):
return err("destination node is banned")
Expand All @@ -910,7 +910,7 @@ proc findNodes*(
p: PortalProtocol, dst: Node, distances: seq[uint16]
): Future[PortalResult[seq[Node]]] {.async: (raises: [CancelledError]).} =
# Fail if no common portal version is found
let _ = ?dst.highestCommonPortalVersion(localSupportedVersions)
let _ = ?dst.highestCommonPortalVersionAndChain(localPortalEnrField)

if p.isBanned(dst.id):
return err("destination node is banned")
Expand All @@ -928,7 +928,7 @@ proc findContent*(
p: PortalProtocol, dst: Node, contentKey: ContentKeyByteList
): Future[PortalResult[FoundContent]] {.async: (raises: [CancelledError]).} =
# Fail if no common portal version is found
let _ = ?dst.highestCommonPortalVersion(localSupportedVersions)
let _ = ?dst.highestCommonPortalVersionAndChain(localPortalEnrField)

logScope:
node = dst
Expand Down Expand Up @@ -1053,7 +1053,7 @@ proc offer(
## guarantee content transfer.

# Fail if no common portal version is found
let _ = ?o.dst.highestCommonPortalVersion(localSupportedVersions)
let _ = ?o.dst.highestCommonPortalVersionAndChain(localPortalEnrField)

let contentKeys = getContentKeys(o)

Expand Down
74 changes: 51 additions & 23 deletions portal/network/wire/portal_protocol_version.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,69 @@
{.push raises: [].}

import
std/sequtils,
ssz_serialization,
eth/rlp,
eth/p2p/discoveryv5/[enr, node],
eth/common/base_rlp,
../../common/common_types

export ssz_serialization
export base_rlp

type PortalVersionValue* = List[uint8, 8]
type PortalEnrField* = object
pvMin: uint8
pvMax: uint8
chainId: ChainId

const
portalVersionKey* = "pv"
localSupportedVersions* = PortalVersionValue(@[1'u8])
portalEnrKey* = "p"
localSupportedVersionMin* = 2'u8
localSupportedVersionMax* = 2'u8
localChainId* = 1.chainId() # Mainnet by default, TODO: runtime configuration
localPortalEnrField* = PortalEnrField(
pvMin: localSupportedVersionMin,
pvMax: localSupportedVersionMax,
chainId: localChainId,
)

func getPortalVersions(record: Record): Result[PortalVersionValue, string] =
let valueBytes = record.get(portalVersionKey, seq[byte]).valueOr:
return ok(PortalVersionValue(@[0'u8]))
func init*(T: type PortalEnrField, pvMin: uint8, pvMax: uint8, chainId: ChainId): T =
T(pvMin: pvMin, pvMax: pvMax, chainId: chainId)

decodeSsz(valueBytes, PortalVersionValue)
func getPortalEnrField(record: Record): Result[PortalEnrField, string] =
let valueBytes = record.get(portalEnrKey, seq[byte]).valueOr:
# When no field, default to version 0 and mainnet chainId
return ok(PortalEnrField(pvMin: 0'u8, pvMax: 0'u8, chainId: 1.chainId()))

func highestCommonPortalVersion(
versions: PortalVersionValue, supportedVersions: PortalVersionValue
let portalField = decodeRlp(valueBytes, PortalEnrField).valueOr:
return err("Failed to decode Portal field: " & error)

if portalField.pvMin > portalField.pvMax:
return err("Invalid Portal ENR field: minimum version > maximum version")

ok(portalField)

func highestCommonPortalVersionAndChain(
a: PortalEnrField, b: PortalEnrField
): Result[uint8, string] =
let commonVersions = versions.filterIt(supportedVersions.contains(it))
if commonVersions.len == 0:
return err("No common protocol versions found")
if a.chainId != b.chainId:
return err("ChainId mismatch: remote=" & $a.chainId & ", local=" & $b.chainId)

let
commonMin = max(a.pvMin, b.pvMin)
commonMax = min(a.pvMax, b.pvMax)

if commonMin > commonMax:
return err("No common Portal wire protocol version found")

ok(max(commonVersions))
ok(commonMax)

func highestCommonPortalVersion*(
record: Record, supportedVersions: PortalVersionValue
func highestCommonPortalVersionAndChain*(
record: Record, supportedPortalField: PortalEnrField
): Result[uint8, string] =
let versions = ?record.getPortalVersions()
versions.highestCommonPortalVersion(supportedVersions)
## Return highest common portal protocol version of both ENRs, but only if chainIds match
let portalField = ?record.getPortalEnrField()
portalField.highestCommonPortalVersionAndChain(supportedPortalField)

func highestCommonPortalVersion*(
node: Node, supportedVersions: PortalVersionValue
func highestCommonPortalVersionAndChain*(
node: Node, supportedPortalField: PortalEnrField
): Result[uint8, string] =
node.record.highestCommonPortalVersion(supportedVersions)
## Return highest common portal protocol version of both nodes, but only if chainIds match
node.record.highestCommonPortalVersionAndChain(supportedPortalField)
2 changes: 1 addition & 1 deletion portal/tests/test_helpers.nim
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ proc initDiscoveryNode*(
enrFields.add(localEnrFields)
# Always inject the portal wire version field into the ENR
# When no field, it would mean v0 only support
enrFields.add((portalVersionKey, SSZ.encode(localSupportedVersions)))
enrFields.add((portalEnrKey, rlp.encode(localPortalEnrField)))

result = newProtocol(
privKey,
Expand Down
62 changes: 37 additions & 25 deletions portal/tests/wire_protocol_tests/test_portal_wire_version.nim
Original file line number Diff line number Diff line change
Expand Up @@ -25,65 +25,77 @@ suite "Portal Wire Protocol Version":

test "ENR with no Portal version field":
let
localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8])
localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId())
enr = Record.init(1, pk, ip, port, port, []).expect("Valid ENR init")

let version = enr.highestCommonPortalVersion(localSupportedVersions)
let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField)
check:
version.isOk()
version.get() == 0'u8

test "ENR with empty Portal version list":
test "ENR with empty Portal ENR field list":
let
localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8])
portalVersions = PortalVersionValue(@[])
customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))]
localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId())
portalEnrField = @[byte 0xc0] # Empty rlp list
customEnrFields = [toFieldPair(portalEnrKey, portalEnrField)]
enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init")

let version = enr.highestCommonPortalVersion(localSupportedVersions)
let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField)
check version.isErr()

test "ENR with unsupported Portal versions":
let
localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8])
portalVersions = PortalVersionValue(@[255'u8, 100'u8, 2'u8])
customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))]
localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId())
portalEnrField = PortalEnrField.init(2'u8, 255'u8, 2.chainId())

customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))]
enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init")

let version = enr.highestCommonPortalVersion(localSupportedVersions)
let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField)
check version.isErr()

test "ENR with supported Portal version":
let
localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8])
portalVersions = PortalVersionValue(@[3'u8, 2'u8, 1'u8])
customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))]
localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId())
portalEnrField = PortalEnrField.init(1'u8, 3'u8, 1.chainId())

customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))]
enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init")

let version = enr.highestCommonPortalVersion(localSupportedVersions)
let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField)
check:
version.isOk()
version.get() == 1'u8

test "ENR with multiple supported Portal versions":
let
localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8, 2'u8])
portalVersions = PortalVersionValue(@[0'u8, 2'u8, 2'u8, 3'u8])
customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))]
localPortalEnrField = PortalEnrField.init(0'u8, 2'u8, 1.chainId())
portalEnrField = PortalEnrField.init(0'u8, 3'u8, 1.chainId())
customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))]
enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init")

let version = enr.highestCommonPortalVersion(localSupportedVersions)
let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField)
check:
version.isOk()
version.get() == 2'u8

test "ENR with too many Portal versions":
test "ENR with invalid Portal version range (min > max)":
let
localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId())
portalEnrField = PortalEnrField.init(2'u8, 1'u8, 1.chainId())
customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))]
enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init")

let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField)
check version.isErr()

test "ENR with supported Portal version but different chain id":
let
localSupportedVersions = PortalVersionValue(@[0'u8, 1'u8, 2'u8])
portalVersions =
PortalVersionValue(@[0'u8, 1'u8, 2'u8, 3'u8, 4'u8, 5'u8, 6'u8, 7'u8, 8'u8])
customEnrFields = [toFieldPair(portalVersionKey, SSZ.encode(portalVersions))]
localPortalEnrField = PortalEnrField.init(0'u8, 1'u8, 1.chainId())
portalEnrField = PortalEnrField.init(1'u8, 3'u8, 2.chainId())

customEnrFields = [toFieldPair(portalEnrKey, rlp.encode(portalEnrField))]
enr = Record.init(1, pk, ip, port, port, customEnrFields).expect("Valid ENR init")

let version = enr.highestCommonPortalVersion(localSupportedVersions)
let version = enr.highestCommonPortalVersionAndChain(localPortalEnrField)
check version.isErr()
Loading