New macro PSA_KEY_TYPE_IS_RSA
diff --git a/include/psa/crypto.h b/include/psa/crypto.h
index 68e3b0a..ba0755b 100644
--- a/include/psa/crypto.h
+++ b/include/psa/crypto.h
@@ -433,7 +433,11 @@
 /** Whether a key type is an RSA key pair or public key. */
 #define PSA_KEY_TYPE_IS_RSA(type)                                       \
     (PSA_KEY_TYPE_PUBLIC_KEY_OF_KEYPAIR(type) == PSA_KEY_TYPE_RSA_PUBLIC_KEY)
-/** Whether a key type is an elliptic curve key pair or public key. */
+/** Whether a key type is an RSA key (pair or public-only). */
+#define PSA_KEY_TYPE_IS_RSA(type)                                       \
+    (PSA_KEY_TYPE_PUBLIC_KEY_OF_KEYPAIR(type) ==                        \
+     PSA_KEY_TYPE_RSA_PUBLIC_KEY)
+/** Whether a key type is an elliptic curve key (pair or public-only). */
 #define PSA_KEY_TYPE_IS_ECC(type)                                       \
     ((PSA_KEY_TYPE_PUBLIC_KEY_OF_KEYPAIR(type) &                        \
       ~PSA_KEY_TYPE_ECC_CURVE_MASK) == PSA_KEY_TYPE_ECC_PUBLIC_KEY_BASE)
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index a1b8104..fac1c75 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -567,9 +567,7 @@
     }
     else
 #if defined(MBEDTLS_PK_PARSE_C)
-    if( type == PSA_KEY_TYPE_RSA_PUBLIC_KEY ||
-        type == PSA_KEY_TYPE_RSA_KEYPAIR ||
-        PSA_KEY_TYPE_IS_ECC( type ) )
+    if( PSA_KEY_TYPE_IS_RSA( type ) || PSA_KEY_TYPE_IS_ECC( type ) )
     {
         int ret;
         mbedtls_pk_context pk;
@@ -584,8 +582,7 @@
         {
 #if defined(MBEDTLS_RSA_C)
             case MBEDTLS_PK_RSA:
-                if( type == PSA_KEY_TYPE_RSA_PUBLIC_KEY ||
-                    type == PSA_KEY_TYPE_RSA_KEYPAIR )
+                if( PSA_KEY_TYPE_IS_RSA( type ) )
                 {
                     mbedtls_rsa_context *rsa = mbedtls_pk_rsa( pk );
                     size_t bits = mbedtls_rsa_get_bitlen( rsa );
@@ -662,8 +659,7 @@
     }
     else
 #if defined(MBEDTLS_RSA_C)
-    if( slot->type == PSA_KEY_TYPE_RSA_PUBLIC_KEY ||
-        slot->type == PSA_KEY_TYPE_RSA_KEYPAIR )
+    if( PSA_KEY_TYPE_IS_RSA( slot->type ) )
     {
         mbedtls_rsa_free( slot->data.rsa );
         mbedtls_free( slot->data.rsa );
@@ -694,8 +690,7 @@
     if( key_type_is_raw_bytes( slot->type ) )
         return( slot->data.raw.bytes * 8 );
 #if defined(MBEDTLS_RSA_C)
-    if( slot->type == PSA_KEY_TYPE_RSA_PUBLIC_KEY ||
-        slot->type == PSA_KEY_TYPE_RSA_KEYPAIR )
+    if( PSA_KEY_TYPE_IS_RSA( slot->type ) )
         return( mbedtls_rsa_get_bitlen( slot->data.rsa ) );
 #endif /* defined(MBEDTLS_RSA_C) */
 #if defined(MBEDTLS_ECP_C)
@@ -769,15 +764,13 @@
     else
     {
 #if defined(MBEDTLS_PK_WRITE_C)
-        if( slot->type == PSA_KEY_TYPE_RSA_PUBLIC_KEY ||
-            slot->type == PSA_KEY_TYPE_RSA_KEYPAIR ||
+        if( PSA_KEY_TYPE_IS_RSA( slot->type ) ||
             PSA_KEY_TYPE_IS_ECC( slot->type ) )
         {
             mbedtls_pk_context pk;
             int ret;
             mbedtls_pk_init( &pk );
-            if( slot->type == PSA_KEY_TYPE_RSA_PUBLIC_KEY ||
-                slot->type == PSA_KEY_TYPE_RSA_KEYPAIR )
+            if( PSA_KEY_TYPE_IS_RSA( slot->type ) )
             {
                 pk.pk_info = &mbedtls_rsa_info;
                 pk.pk_ctx = slot->data.rsa;
@@ -2064,8 +2057,7 @@
         return( status );
 
 #if defined(MBEDTLS_RSA_C)
-    if( slot->type == PSA_KEY_TYPE_RSA_KEYPAIR ||
-        slot->type == PSA_KEY_TYPE_RSA_PUBLIC_KEY )
+    if( PSA_KEY_TYPE_IS_RSA( slot->type ) )
     {
         return( psa_rsa_verify( slot->data.rsa,
                                 alg,
@@ -2120,8 +2112,7 @@
         return( PSA_ERROR_INVALID_ARGUMENT );
 
 #if defined(MBEDTLS_RSA_C)
-    if( slot->type == PSA_KEY_TYPE_RSA_KEYPAIR ||
-        slot->type == PSA_KEY_TYPE_RSA_PUBLIC_KEY )
+    if( PSA_KEY_TYPE_IS_RSA( slot->type ) )
     {
         mbedtls_rsa_context *rsa = slot->data.rsa;
         int ret;