1
// Copyright (C) Moondance Labs Ltd.
2
// This file is part of Tanssi.
3

            
4
// Tanssi is free software: you can redistribute it and/or modify
5
// it under the terms of the GNU General Public License as published by
6
// the Free Software Foundation, either version 3 of the License, or
7
// (at your option) any later version.
8

            
9
// Tanssi is distributed in the hope that it will be useful,
10
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
// GNU General Public License for more details.
13

            
14
// You should have received a copy of the GNU General Public License
15
// along with Tanssi.  If not, see <http://www.gnu.org/licenses/>
16
use alloy_core::sol;
17
use frame_support::pallet_prelude::{Decode, Encode};
18
use frame_support::DebugNoBound;
19
use scale_info::TypeInfo;
20
use sp_core::{ConstU32, DecodeWithMemTracking};
21
use sp_runtime::BoundedVec;
22

            
23
/// Maximum size of a LayerZero message payload in bytes (8 KB).
24
/// This limit prevents memory exhaustion from arbitrarily large payloads.
25
pub const MAX_LAYERZERO_PAYLOAD_SIZE: u32 = 8 * 1024;
26

            
27
/// Bounded payload type for inbound LayerZero messages.
28
pub type LayerZeroInboundPayload = BoundedVec<u8, ConstU32<MAX_LAYERZERO_PAYLOAD_SIZE>>;
29

            
30
/// Bounded payload type for outbound LayerZero messages.
31
pub type LayerZeroOutboundPayload = BoundedVec<u8, ConstU32<MAX_LAYERZERO_PAYLOAD_SIZE>>;
32

            
33
pub type LayerZeroAddress = BoundedVec<u8, ConstU32<32>>;
34
pub type LayerZeroEndpoint = u32;
35

            
36
sol! {
37
    struct InboundSolMessage {
38
        bytes32 lzSourceAddress;
39
        uint32  lzSourceEndpoint;
40
        uint32  destinationChain;
41
        bytes   payload;
42
    }
43
}
44

            
45
#[derive(Encode, Decode, DecodeWithMemTracking, Clone, DebugNoBound, PartialEq, Eq, TypeInfo)]
46
pub struct InboundMessage {
47
    pub lz_source_address: LayerZeroAddress,
48
    pub lz_source_endpoint: LayerZeroEndpoint,
49
    pub destination_chain: u32,
50
    pub payload: LayerZeroInboundPayload,
51
}
52

            
53
/// Error when converting from InboundSolMessage to InboundMessage
54
#[derive(Debug, Clone, PartialEq, Eq)]
55
pub enum InboundMessageConversionError {
56
    /// The message payload exceeds the maximum allowed size
57
    PayloadTooLarge { size: usize, max: u32 },
58
}
59

            
60
impl core::fmt::Display for InboundMessageConversionError {
61
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
62
        match self {
63
            Self::PayloadTooLarge { size, max } => {
64
                write!(f, "payload size {} exceeds maximum {}", size, max)
65
            }
66
        }
67
    }
68
}
69

            
70
// from InboundSolMessage to InboundMessage
71
impl TryFrom<InboundSolMessage> for InboundMessage {
72
    type Error = InboundMessageConversionError;
73

            
74
33
    fn try_from(sol_message: InboundSolMessage) -> Result<Self, Self::Error> {
75
33
        let payload_bytes: alloc::vec::Vec<u8> = sol_message.payload.into();
76
33
        let payload_len = payload_bytes.len();
77
33
        let payload: LayerZeroInboundPayload = payload_bytes.try_into().map_err(|_| {
78
            InboundMessageConversionError::PayloadTooLarge {
79
                size: payload_len,
80
                max: MAX_LAYERZERO_PAYLOAD_SIZE,
81
            }
82
        })?;
83

            
84
33
        Ok(Self {
85
33
            lz_source_address: sol_message
86
33
                .lzSourceAddress
87
33
                .to_vec()
88
33
                .try_into()
89
33
                .expect("lzSourceAddress is always 32 bytes; qed"),
90
33
            lz_source_endpoint: sol_message.lzSourceEndpoint,
91
33
            destination_chain: sol_message.destinationChain,
92
33
            payload,
93
33
        })
94
33
    }
95
}
96

            
97
sol! {
98
    struct OutboundSolMessage {
99
        uint32  sourceChain;
100
        bytes32 lzDestinationAddress;
101
        uint32  lzDestinationEndpoint;
102
        bytes   payload;
103
    }
104
}
105

            
106
#[derive(Encode, Decode, DecodeWithMemTracking, Clone, DebugNoBound, PartialEq, Eq, TypeInfo)]
107
pub struct OutboundMessage {
108
    pub source_chain: u32,
109
    pub lz_destination_address: LayerZeroAddress,
110
    pub lz_destination_endpoint: LayerZeroEndpoint,
111
    pub payload: LayerZeroOutboundPayload,
112
}
113

            
114
// from OutboundMessage to OutboundSolMessage
115
impl From<OutboundMessage> for OutboundSolMessage {
116
297
    fn from(message: OutboundMessage) -> Self {
117
297
        let mut destination_address = [0u8; 32];
118
297
        let addr_slice = message.lz_destination_address.as_slice();
119
297
        let len = addr_slice.len().min(32);
120
297
        destination_address[..len].copy_from_slice(&addr_slice[..len]);
121

            
122
297
        Self {
123
297
            sourceChain: message.source_chain,
124
297
            lzDestinationAddress: alloy_core::primitives::FixedBytes(destination_address),
125
297
            lzDestinationEndpoint: message.lz_destination_endpoint,
126
297
            payload: message.payload.to_vec().into(),
127
297
        }
128
297
    }
129
}