RAG Protein Tracker
Protein tracker app that uses RAG and LangChain
Tech Stack
Front and Back End
- React Native, Expo, Redux, FastAPI, LangChain
- OpenAI
- Render.com
Features
Home Tab
Recipes Tab
Search
Technical Challenges
RAG
In many cases, users may not know the exact amount of protein they need to add; they just have a general description of the food. RAG ended up being a simple solution for this feature. On the backend, there’s a SQLite database with hundreds of thousands of generic and branded foods. During the retrieval phase of RAG, a full-text search is performed over the database to match foods from the input text and get their nutritional information. After that, the query to OpenAI is augmented with this data to give the model accurate context so it doesn’t hallucinate. The generated response is then returned to the front end as JSON and presented to the user.
Backend Request Integrity
Since RAG requests are handled over an API endpoint, leaving that exposed to the public internet would risk the possibility of an attacker spamming the endpoint and racking up the OpenAI bill. Both Apple and Google have solutions for proving requests to a backend are coming from a valid mobile app instance and not from some other malicious source. Apple provides something called App Attest Service, and Google provides its Play Integrity API.
Implementing these solutions involved creating a custom native module for Expo, and writing some backend logic to verify the information provided by App Attest and the Play Integrity API.
For iOS, the flow goes like this:
For Android, the flow is a little simpler, but the challenge counter has to be incremented manually: the Play Integrity token doesn’t include a counter to protect against replay attacks, so that information has to be carried in the nonce. In my implementation, the challenge is formatted as <id>.<value>.<counter>, and the frontend increments the counter part on Android devices. On iOS, the frontend leaves it unchanged and treats it as part of the value.
All in all, this was the main code for validating the token, attestation, and assertion respectively.
import base64
from cryptography.hazmat.primitives.keywrap import aes_key_unwrap
from cryptography.hazmat.backends import default_backend
from redis import Redis
import jwt
from cryptography.hazmat.primitives import serialization
from utils.get_secret import get_secret
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cache.store import CHALLENGE_PREFIX, CHALLENGE_COUNTER_BIT_LENGTH
# Load the Base64-encoded keys from your secret manager.
# Both lookups run once, when this module is imported; validate_token
# base64-decodes the values on each call.
GOOGLE_PLAY_INTEGRITY_DECRYPTION_KEY = get_secret(
"PLAY_INTEGRITY_DECRYPTION_KEY")
GOOGLE_PLAY_INTEGRITY_VERIFICATION_KEY = get_secret(
"PLAY_INTEGRITY_VERIFICATION_KEY")
def add_padding(b64_str):
    """Return ``b64_str`` with enough trailing ``=`` to reach a multiple of 4.

    The token segments handled in this module arrive without Base64 padding,
    so they are padded here before being passed to the base64 decoders.
    """
    missing = -len(b64_str) % 4
    padding = '=' * missing
    return b64_str + padding
def validate_challenge(redis_client: Redis, decoded_token: dict, challenge: str):
    """Check the nonce of a decoded integrity token against the stored challenge.

    ``challenge`` is dot-separated: its first segment is the challenge id used
    to build the Redis key, its second segment is the challenge root value.
    The token nonce (``decoded_token['requestDetails']['nonce']``) is the root
    value followed by a counter occupying the last
    ``CHALLENGE_COUNTER_BIT_LENGTH`` characters (the constant is used as a
    character count here, whatever its name suggests).
    The value stored in Redis has the form ``<root>.<counter>``.

    The challenge is accepted when all three root values agree and the counter
    in the nonce is strictly greater than the stored one; the stored value is
    then replaced with the new ``<root>.<counter>``.

    Args:
        redis_client: Redis client holding the challenge state.
        decoded_token: Verified integrity token payload.
        challenge: Challenge string sent alongside the token.

    Raises:
        ValueError: If the challenge is malformed, unknown, does not match,
            or its counter did not increase.
        KeyError: If the token payload has no ``requestDetails.nonce``.
    """
    challenge_parts = challenge.split('.')
    if len(challenge_parts) < 2:
        raise ValueError("Invalid challenge")
    challenge_id, challenge_root_value = challenge_parts[0], challenge_parts[1]

    request_challenge = decoded_token['requestDetails']['nonce']
    request_challenge_counter = request_challenge[-CHALLENGE_COUNTER_BIT_LENGTH:]
    request_challenge_root_value = request_challenge[:-CHALLENGE_COUNTER_BIT_LENGTH]

    redis_key = f"{CHALLENGE_PREFIX}{challenge_id}"
    stored_challenge = redis_client.get(redis_key)
    if stored_challenge is None:
        # Unknown or expired challenge id: reject instead of crashing on None.
        raise ValueError("Invalid challenge")
    if isinstance(stored_challenge, bytes):
        # Redis returns bytes unless the client uses decode_responses=True.
        stored_challenge = stored_challenge.decode()

    stored_parts = stored_challenge.split('.')
    if len(stored_parts) < 2:
        raise ValueError("Invalid challenge")
    stored_challenge_root_value, stored_challenge_counter = (
        stored_parts[0], stored_parts[1])

    if not (stored_challenge_root_value == challenge_root_value
            == request_challenge_root_value):
        raise ValueError("Invalid challenge")

    # int() raises ValueError on a non-numeric counter, which also rejects.
    if int(request_challenge_counter) <= int(stored_challenge_counter):
        raise ValueError("Invalid challenge counter")

    # NOTE(review): the read above and this write are not atomic, so two
    # concurrent requests carrying the same counter could both pass.
    redis_client.set(
        redis_key,
        request_challenge_root_value + '.' + request_challenge_counter,
    )
def validate_token(redis_client: Redis, token: str, challenge: str) -> bool:
    """Decrypt and verify a Play Integrity token, then validate its challenge.

    The token is treated as a compact JWE with exactly five dot-separated
    segments (header, encrypted key, IV, ciphertext, auth tag). The content
    encryption key is AES-key-unwrapped with the configured decryption key,
    the ciphertext is decrypted with AES-GCM, and the resulting JWS is verified
    with the configured ES256 verification key. Finally the nonce inside the
    payload is checked by ``validate_challenge``.

    NOTE(review): only ``requestDetails.nonce`` is inspected. The rest of the
    payload (integrity verdict fields) is not evaluated in this function;
    confirm a caller does that if it is required.

    Args:
        redis_client: Redis client passed through to ``validate_challenge``.
        token: Integrity token received from the client.
        challenge: Challenge string received alongside the token.

    Returns:
        True if every step succeeds; False on any failure. The error is
        printed before False is returned.
    """
    try:
        # Decode the base64-encoded decryption and verification keys.
        decryption_key_bytes = base64.b64decode(
            GOOGLE_PLAY_INTEGRITY_DECRYPTION_KEY)
        verification_key_bytes = base64.b64decode(
            GOOGLE_PLAY_INTEGRITY_VERIFICATION_KEY)
        public_key = serialization.load_der_public_key(verification_key_bytes)

        # A compact JWE has exactly five segments; reject anything else
        # up front instead of ignoring extras or failing with an IndexError.
        parts = token.split('.')
        if len(parts) != 5:
            raise ValueError(
                f"expected 5 JWE segments, got {len(parts)}")
        header, encrypted_key, iv, ciphertext, auth_tag = (
            base64.urlsafe_b64decode(add_padding(part)) for part in parts
        )

        # Unwrap the CEK (Content Encryption Key) using the decryption key.
        cek = aes_key_unwrap(
            decryption_key_bytes,
            encrypted_key,
            backend=default_backend(),
        )

        # The Additional Authenticated Data is the unpadded base64url header.
        aad = base64.urlsafe_b64encode(header).rstrip(b"=")

        # AESGCM.decrypt expects the auth tag appended to the ciphertext.
        plaintext = AESGCM(cek).decrypt(iv, ciphertext + auth_tag, aad)

        # The plaintext is a JWS; verify its signature with the
        # Google-provided verification key.
        decoded_token = jwt.decode(
            plaintext.decode("utf-8"),
            public_key,
            algorithms=["ES256"],
            # Google's token doesn't use an audience field, so skip that check
            options={"verify_aud": False},
        )

        validate_challenge(redis_client, decoded_token, challenge)
        return True
    except Exception as e:
        # Boundary of the validation flow: any failure means "not valid".
        print(f"Error in validate_token: {e}")
        return False
import base64
import hmac
import os
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.constant_time import bytes_eq
from cryptography.hazmat.primitives import serialization
def _to_fixed_length_hex(value: str, length: int = 16):
    """Encode ``value``, fit it to exactly ``length`` bytes, and return hex.

    Shorter values are right-padded with NUL bytes; longer values are
    truncated.
    """
    raw = value.encode()
    zero_fill = b'\x00' * max(0, length - len(raw))
    fitted = (raw + zero_fill)[:length]
    return fitted.hex()
def _sha256(data: bytes) -> bytes:
    """Return the raw SHA-256 digest of ``data``."""
    hasher = hashes.Hash(hashes.SHA256(), default_backend())
    hasher.update(data)
    digest = hasher.finalize()
    return digest
# Which App Attest environment attestations are expected to come from.
# NOTE(review): when the variable is unset this falls back to "development",
# so a deployment that forgets to set it expects the development AAGUID.
INTEGRITY_ENVIRONMENT = os.getenv("INTEGRITY_ENVIRONMENT", "development")
# OID of the certificate extension read by _validate_nonce.
OID_APPLE = x509.ObjectIdentifier("1.2.840.113635.100.8.2")
# AAGUID markers compared against authData[37:53] in _validate_environment.
PROD_AAGUID = 'appattest'
DEV_AAGUID = 'appattestdevelop'
# Hex form of each marker, NUL-padded to 16 bytes ('appattest' gains seven
# 0x00 bytes; 'appattestdevelop' is already 16 bytes long).
PROD_AAGUID_HEX = _to_fixed_length_hex(PROD_AAGUID)
DEV_AAGUID_HEX = _to_fixed_length_hex(DEV_AAGUID)
# PEM-encoded root certificate used as the trust anchor for the attestation
# certificate chain (see _validate_certificate_chain).
apple_attestation_root_ca = """-----BEGIN CERTIFICATE-----
MIICITCCAaegAwIBAgIQC/O+DvHN0uD7jG5yH2IXmDAKBggqhkjOPQQDAzBSMSYw
JAYDVQQDDB1BcHBsZSBBcHAgQXR0ZXN0YXRpb24gUm9vdCBDQTETMBEGA1UECgwK
QXBwbGUgSW5jLjETMBEGA1UECAwKQ2FsaWZvcm5pYTAeFw0yMDAzMTgxODMyNTNa
Fw00NTAzMTUwMDAwMDBaMFIxJjAkBgNVBAMMHUFwcGxlIEFwcCBBdHRlc3RhdGlv
biBSb290IENBMRMwEQYDVQQKDApBcHBsZSBJbmMuMRMwEQYDVQQIDApDYWxpZm9y
bmlhMHYwEAYHKoZIzj0CAQYFK4EEACIDYgAERTHhmLW07ATaFQIEVwTtT4dyctdh
NbJhFs/Ii2FdCgAHGbpphY3+d8qjuDngIN3WVhQUBHAoMeQ/cLiP1sOUtgjqK9au
Yen1mMEvRq9Sk3Jm5X8U62H+xTD3FE9TgS41o0IwQDAPBgNVHRMBAf8EBTADAQH/
MB0GA1UdDgQWBBSskRBTM72+aEH/pwyp5frq5eWKoTAOBgNVHQ8BAf8EBAMCAQYw
CgYIKoZIzj0EAwMDaAAwZQIwQgFGnByvsiVbpTKwSga0kP0e8EeDS4+sQmTvb7vn
53O5+FRXgeLhpJ06ysC5PrOyAjEAp5U4xDgEgllF7En3VcE3iexZZtKeYnpqtijV
oyFraWVIyd/dganmrduC1bmTBGwD
-----END CERTIFICATE-----""".encode()
# Create X509 certificate object (parsed once, at import time).
root_cert = x509.load_pem_x509_certificate(
apple_attestation_root_ca, default_backend()
)
def get_public_key(parsed_data: dict) -> str:
    """Return the base64 SHA-256 hash of the credential certificate's key.

    Despite the name, the return value is not the key itself: the public key
    of the first ``attStmt.x5c`` certificate is serialized as an X9.62
    uncompressed point, hashed with SHA-256 and base64-encoded.

    Raises:
        ValueError: If the attestation carries no ``x5c`` certificates.
    """
    cert_chain = parsed_data.get('attStmt', {}).get('x5c', [])
    if not cert_chain:
        raise ValueError("No x5c found in attestation")

    leaf_cert = x509.load_der_x509_certificate(cert_chain[0])
    uncompressed_point = leaf_cert.public_key().public_bytes(
        encoding=serialization.Encoding.X962,
        format=serialization.PublicFormat.UncompressedPoint,
    )
    key_digest = _sha256(uncompressed_point)
    return base64.b64encode(key_digest).decode()
def _validate_certificate_chain(parsed_data: dict) -> bool:
    """Step 1: Validate the certificate chain.

    Every certificate in ``attStmt.x5c`` must be directly issued by the next
    one in the list, and the last one by the module-level ``root_cert``.

    Returns:
        False when ``x5c`` is missing or empty, True when every link verifies.
        A link that fails verification raises from
        ``verify_directly_issued_by`` rather than returning False.

    NOTE(review): only direct issuance is verified here; no validity-period
    check is performed in this function.
    """
    chain_der = parsed_data.get('attStmt', {}).get('x5c', [])
    if not chain_der:
        return False

    chain = [x509.load_der_x509_certificate(der) for der in chain_der]
    # Pair each certificate with its expected issuer; the root closes the chain.
    issuers = chain[1:] + [root_cert]
    for subject, issuer in zip(chain, issuers):
        subject.verify_directly_issued_by(issuer)
    return True
def _validate_nonce(parsed_data: dict, challenge: str) -> bool:
    """Step 2: Verify the nonce embedded in the credential certificate.

    - hash the challenge
    - compute the expected nonce as SHA256(authData + SHA256(challenge))
    - compare it, in constant time, with the nonce carried in the
      certificate's ``OID_APPLE`` extension
    """
    leaf_der = parsed_data.get('attStmt', {}).get('x5c', [])[0]
    leaf_cert = x509.load_der_x509_certificate(leaf_der, default_backend())

    # The challenge has to be hashed first and appended to the auth data
    # before hashing the whole thing. The example given by Apple might lead
    # you not to do this, but in practice it is needed.
    challenge_digest = _sha256(challenge.encode())
    auth_data = parsed_data['authData']
    expected_nonce = _sha256(auth_data + challenge_digest)

    apple_extension = leaf_cert.extensions.get_extension_for_oid(OID_APPLE)
    # The first 6 bytes of the extension value are skipped; presumably they
    # are the encoding wrapper around the nonce (TODO confirm).
    embedded_nonce = apple_extension.value.value[6:]

    return bytes_eq(expected_nonce, embedded_nonce)
def _validate_key_consistency(parsed_data: dict, key_id: str) -> bool:
    """Step 3: Verify key consistency.

    Recomputes the key identifier from the attestation (base64 of the
    SHA-256 of the credential public key, via ``get_public_key``) and
    compares it in constant time with the ``key_id`` sent by the app.
    """
    computed_key_id = get_public_key(parsed_data)
    is_same_key = hmac.compare_digest(computed_key_id, key_id)
    return is_same_key
def _validate_rp_id(parsed_data: dict, app_id: str) -> bool:
    """Step 4: Verify the RP id.

    The first 32 bytes of ``authData`` (the rpIdHash) must equal
    SHA256(``app_id``); the comparison is constant-time.
    """
    rp_id_hash = parsed_data['authData'][:32]
    app_id_hash = _sha256(app_id.encode())
    return hmac.compare_digest(rp_id_hash, app_id_hash)
def _validate_counter(parsed_data: dict) -> bool:
    """Step 5: Verify that the attestation counter is zero.

    The counter is read as a big-endian integer from ``authData[33:37]``.
    For a first attestation it must be 0.

    Returns:
        True only when a full 4-byte counter is present and equals 0.
        Truncated ``authData`` is rejected: an empty or short slice would
        otherwise decode to 0 and pass the check.
    """
    counter_bytes = parsed_data['authData'][33:37]
    if len(counter_bytes) != 4:
        return False
    return int.from_bytes(counter_bytes, 'big') == 0
def _validate_environment(parsed_data: dict) -> bool:
    """Step 6: Verify the attestation environment.

    The 16-byte aaguid field, ``authData[37:53]``, is compared as hex with
    the value expected for ``INTEGRITY_ENVIRONMENT``:
    - "development": ``DEV_AAGUID_HEX`` ("appattestdevelop")
    - anything else: ``PROD_AAGUID_HEX`` ("appattest" followed by seven
      0x00 bytes)
    """
    aaguid_hex = parsed_data['authData'][37:53].hex()
    expected_aaguid = (
        DEV_AAGUID_HEX
        if INTEGRITY_ENVIRONMENT == "development"
        else PROD_AAGUID_HEX
    )
    return hmac.compare_digest(aaguid_hex, expected_aaguid)
def _validate_auth_data_credential_id(parsed_data: dict, key_id: str) -> bool:
    """Step 7: Verify the credential ID.

    ``authData[53:55]`` holds the big-endian length of the credential ID,
    which follows immediately at offset 55. Its base64 encoding must equal
    the ``key_id`` sent by the app (constant-time comparison).
    """
    id_offset = 55
    id_length = int.from_bytes(parsed_data['authData'][53:55], 'big')
    raw_credential_id = parsed_data['authData'][id_offset:id_offset + id_length]
    encoded_credential_id = base64.b64encode(raw_credential_id).decode()
    return hmac.compare_digest(encoded_credential_id, key_id)
def validate_attestation(
    attestation: dict,
    challenge: str,
    key_id: str,
    app_id: str,
) -> bool:
    """Run every attestation validation step and report the overall result.

    Steps run in a fixed order and stop at the first one that fails:
    certificate chain, nonce, RP id, key consistency, environment, counter,
    credential ID.

    Args:
        attestation: Decoded attestation object (``attStmt``, ``authData``).
        challenge: Challenge the attestation was generated for.
        key_id: Key identifier sent by the app.
        app_id: App ID whose hash must match the rpIdHash.

    Returns:
        True if every step passes; False if any step returns a falsy value
        or raises. A raised error is printed first, matching how the other
        validators in this file report failures.
    """
    validation_steps = (
        lambda: _validate_certificate_chain(attestation),
        lambda: _validate_nonce(attestation, challenge),
        lambda: _validate_rp_id(attestation, app_id),
        lambda: _validate_key_consistency(attestation, key_id),
        lambda: _validate_environment(attestation),
        lambda: _validate_counter(attestation),
        lambda: _validate_auth_data_credential_id(attestation, key_id),
    )
    try:
        # all() short-circuits, so later steps are skipped after a failure.
        return all(step() for step in validation_steps)
    except Exception as e:
        # Malformed input or a failed signature check surfaces as an
        # exception from a step; treat it as "not valid" but leave a trace.
        print(f"Error in validate_attestation: {e}")
        return False
import json
import hashlib
import base64
import cbor2
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cache import KEY_CHALLENGE_PREFIX, KEY_COUNTER_PREFIX, KEY_PUBLIC_KEY_PREFIX
from redis import Redis
def _verify_nonce(assertion: dict, client_data: dict, public_key: str) -> bool:
    """Step 1: Verify the assertion signature over the nonce.

    The nonce is SHA256(authenticatorData + SHA256(client data JSON)), where
    the client data is serialized compactly (no spaces after separators).
    The signature is checked with the base64/DER public key stored at
    attestation time.

    Returns:
        True when the signature verifies. An invalid signature does not
        return False; ``verify`` raises instead.

    NOTE(review): the client data is re-serialized here, so its key order and
    formatting must match what the client hashed (verify against the app).
    """
    serialized_client_data = json.dumps(client_data, separators=(',', ':'))
    client_data_hash = hashlib.sha256(serialized_client_data.encode()).digest()
    signed_payload = assertion['authenticatorData'] + client_data_hash
    nonce = hashlib.sha256(signed_payload).digest()

    # Load the stored public key.
    verifier = serialization.load_der_public_key(base64.b64decode(public_key))

    # Raises if the signature does not match.
    verifier.verify(
        assertion['signature'],
        nonce,
        ec.ECDSA(hashes.SHA256()),
    )
    return True
def _verify_rp_id(assertion: dict, app_id: str) -> bool:
    """Step 2: Check that the RP ID hash matches the client's App ID.

    Compares the hex SHA-256 of ``app_id`` with the hex of the first 32 bytes
    of the assertion's ``authenticatorData``.
    """
    expected_hex = hashlib.sha256(app_id.encode()).hexdigest()
    rp_id_hash = assertion['authenticatorData'][:32]
    return expected_hex == rp_id_hash.hex()
def validate_assertion(
    redis_client: Redis,
    client_data: dict,
    assertion: str,
    key_id: str,
    app_id: str,
) -> bool:
    """Validate an assertion and, on success, advance the stored counter.

    The assertion is accepted only if all of the following hold:
    - a public key and a challenge are stored for ``key_id``
    - the assertion counter is greater than the stored counter
    - ``client_data['challenge']`` equals the stored challenge
    - the RP ID hash matches ``app_id``
    - the signature verifies against the stored public key

    The stored counter is updated only for an accepted assertion, so a
    rejected or replayed assertion can never move it (in particular, never
    backwards).

    Args:
        redis_client: Redis client holding per-key state.
        client_data: Client data that was signed by the app.
        assertion: Base64-encoded CBOR assertion.
        key_id: Key identifier the state is stored under.
        app_id: App ID whose hash must match the rpIdHash.

    Returns:
        True if the assertion is valid; False otherwise. Unexpected errors
        are printed and reported as False.
    """
    try:
        counter_key = f"{KEY_COUNTER_PREFIX}{key_id}"
        # Batch the three Redis GET operations into a single MGET call.
        public_key, last_challenge, counter = redis_client.mget([
            f"{KEY_PUBLIC_KEY_PREFIX}{key_id}",
            f"{KEY_CHALLENGE_PREFIX}{key_id}",
            counter_key,
        ])

        # Unknown key or no outstanding challenge: nothing to validate
        # against. Also prevents a None challenge from matching None.
        if public_key is None or last_challenge is None:
            return False
        if isinstance(last_challenge, bytes):
            # Redis returns bytes unless the client uses decode_responses=True;
            # bytes would never compare equal to the str challenge.
            last_challenge = last_challenge.decode()
        counter = int(counter) if counter else 0

        parsed_assertion = cbor2.loads(base64.b64decode(assertion))
        # NOTE(review): this reads everything from byte 32 onward, while
        # _validate_counter in the attestation code reads bytes 33..36.
        # If byte 32 is a flags byte it is folded into this value. Kept as
        # is because previously stored counters use the same encoding.
        assertion_count = int.from_bytes(
            parsed_assertion['authenticatorData'][32:], 'big')

        # Cheap checks first; the signature check runs last.
        passed = (
            assertion_count > counter
            and client_data['challenge'] == last_challenge
            and _verify_rp_id(parsed_assertion, f"{app_id}")
            and _verify_nonce(parsed_assertion, client_data, public_key)
        )

        if passed:
            # NOTE(review): read-then-write is not atomic; concurrent
            # requests with the same counter could both pass.
            redis_client.set(counter_key, f'{assertion_count}')
        return passed
    except Exception as e:
        print(f"Error in validate_assertion: {e}")
        return False