Handle top bits in upper VA ranges

Set and enforce top bits in virtual address if Upper VA range is
selected for EL1/0 or EL2/0 (VHE). This means that the top bits must
match the selected VA range and Xlat will also return matching virtual
addresses.

Signed-off-by: Imre Kis <imre.kis@arm.com>
Change-Id: I44c2a326a9d3fdd4d82ec01e8f95d1c8f7d305b1
diff --git a/src/address.rs b/src/address.rs
index 5068f29..c5cb6ff 100644
--- a/src/address.rs
+++ b/src/address.rs
@@ -3,6 +3,8 @@
 
 use core::{fmt, ops::Range};
 
+use crate::TranslationRegime;
+
 use super::TranslationGranule;
 
 #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
@@ -94,6 +96,27 @@
         self.0 >> translation_granule.total_bits_at_level(level)
     }
 
+    pub fn is_valid_in_regime<const VA_BITS: usize>(&self, regime: TranslationRegime) -> bool {
+        let mask = Self::get_upper_bit_mask::<VA_BITS>();
+        let required_upper_bits = if regime.is_upper_va_range() { mask } else { 0 };
+
+        (self.0 & mask) == required_upper_bits
+    }
+
+    pub fn set_upper_bits<const VA_BITS: usize>(self, regime: TranslationRegime) -> Self {
+        let mask = Self::get_upper_bit_mask::<VA_BITS>();
+
+        Self(if regime.is_upper_va_range() {
+            self.0 | mask
+        } else {
+            self.0 & !mask
+        })
+    }
+
+    pub fn remove_upper_bits<const VA_BITS: usize>(self) -> Self {
+        Self(self.0 & !Self::get_upper_bit_mask::<VA_BITS>())
+    }
+
     pub const fn mask_bits(self, mask: usize) -> Self {
         Self(self.0 & mask)
     }
@@ -105,6 +128,10 @@
     pub const fn align_up(self, alignment: usize) -> Self {
         Self(self.0.next_multiple_of(alignment))
     }
+
+    const fn get_upper_bit_mask<const VA_BITS: usize>() -> usize {
+        !((1 << VA_BITS) - 1)
+    }
 }
 
 impl From<VirtualAddress> for usize {
@@ -127,6 +154,7 @@
     }
 }
 
+#[derive(Debug)]
 pub struct VirtualAddressRange {
     pub(super) start: VirtualAddress,
     pub(super) end: VirtualAddress,
diff --git a/src/kernel_space.rs b/src/kernel_space.rs
index f9e1d68..cc04541 100644
--- a/src/kernel_space.rs
+++ b/src/kernel_space.rs
@@ -37,7 +37,9 @@
         Self {
             xlat: Arc::new(Mutex::new(Xlat::new(
                 page_pool,
-                unsafe { VirtualAddressRange::from_range(0x0000_0000..0x10_0000_0000) },
+                unsafe {
+                    VirtualAddressRange::from_range(0xffff_fff0_0000_0000..0xffff_ffff_ffff_ffff)
+                },
                 TranslationRegime::EL1_0(RegimeVaRange::Upper, 0),
                 TranslationGranule::Granule4k,
             ))),
@@ -57,18 +59,26 @@
     ) -> Result<(), XlatError> {
         let mut xlat = self.xlat.lock();
 
-        let code_pa = PhysicalAddress(code_range.start);
-        let data_pa = PhysicalAddress(data_range.start);
+        let code_pa = PhysicalAddress(code_range.start & 0x0000_000f_ffff_ffff);
+        let data_pa = PhysicalAddress(data_range.start & 0x0000_000f_ffff_ffff);
 
         xlat.map_physical_address_range(
-            Some(code_pa.identity_va()),
+            Some(
+                code_pa
+                    .identity_va()
+                    .set_upper_bits::<36>(TranslationRegime::EL1_0(RegimeVaRange::Upper, 0)),
+            ),
             code_pa,
             code_range.len(),
             MemoryAccessRights::RX | MemoryAccessRights::GLOBAL,
         )?;
 
         xlat.map_physical_address_range(
-            Some(data_pa.identity_va()),
+            Some(
+                data_pa
+                    .identity_va()
+                    .set_upper_bits::<36>(TranslationRegime::EL1_0(RegimeVaRange::Upper, 0)),
+            ),
             data_pa,
             data_range.len(),
             MemoryAccessRights::RW | MemoryAccessRights::GLOBAL,
@@ -92,14 +102,17 @@
     ) -> Result<usize, XlatError> {
         let pa = PhysicalAddress(pa);
 
-        let lower_va = self.xlat.lock().map_physical_address_range(
-            Some(pa.identity_va()),
+        let va = self.xlat.lock().map_physical_address_range(
+            Some(
+                pa.identity_va()
+                    .set_upper_bits::<36>(TranslationRegime::EL1_0(RegimeVaRange::Upper, 0)),
+            ),
             pa,
             length,
             access_rights | MemoryAccessRights::GLOBAL,
         )?;
 
-        Ok(Self::pa_to_kernel(lower_va.0 as u64) as usize)
+        Ok(va.0)
     }
 
     /// Unmap memory range from the kernel address space
@@ -109,10 +122,9 @@
     /// # Return value
     /// The result of the operation
     pub fn unmap_memory(&self, va: usize, length: usize) -> Result<(), XlatError> {
-        self.xlat.lock().unmap_virtual_address_range(
-            VirtualAddress(Self::kernel_to_pa(va as u64) as usize),
-            length,
-        )
+        self.xlat
+            .lock()
+            .unmap_virtual_address_range(VirtualAddress(va), length)
     }
 
     /// Activate kernel address space mapping
diff --git a/src/lib.rs b/src/lib.rs
index b5f95d6..c8abbdd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -114,13 +114,13 @@
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy)]
 pub enum RegimeVaRange {
     Lower,
     Upper,
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy)]
 pub enum TranslationRegime {
     EL1_0(RegimeVaRange, u8), // EL1 and EL0 stage 1, TTBRx_EL1
     #[cfg(target_feature = "vh")]
@@ -129,6 +129,17 @@
     EL3,                      // EL3, TTBR0_EL3
 }
 
+impl TranslationRegime {
+    fn is_upper_va_range(&self) -> bool {
+        match self {
+            TranslationRegime::EL1_0(RegimeVaRange::Upper, _) => true,
+            #[cfg(target_feature = "vh")]
+            EL2_0(RegimeVaRange::Upper, _) => true,
+            _ => false,
+        }
+    }
+}
+
 pub type TranslationGranule<const VA_BITS: usize> = granule::TranslationGranule<VA_BITS>;
 
 pub struct Xlat<const VA_BITS: usize> {
@@ -173,6 +184,15 @@
     ) -> Self {
         let initial_lookup_level = granule.initial_lookup_level();
 
+        if !address.start.is_valid_in_regime::<VA_BITS>(regime)
+            || !address.end.is_valid_in_regime::<VA_BITS>(regime)
+        {
+            panic!(
+                "Invalid address range {:?} for regime {:?}",
+                address, regime
+            );
+        }
+
         let base_table = page_pool
             .allocate_pages(
                 granule.table_size::<Descriptor>(initial_lookup_level),
@@ -547,7 +567,7 @@
     ) -> Result<VirtualAddress, XlatError> {
         let blocks = BlockIterator::new(
             region.get_pa(),
-            region.base(),
+            region.base().remove_upper_bits::<VA_BITS>(),
             region.length(),
             self.granule,
         )?;
@@ -565,7 +585,7 @@
     fn unmap_region(&mut self, region: &VirtualRegion) -> Result<(), XlatError> {
         let blocks = BlockIterator::new(
             region.get_pa(),
-            region.base(),
+            region.base().remove_upper_bits::<VA_BITS>(),
             region.length(),
             self.granule,
         )?;