The curious case of stateless JSON Web Tokens

You decided not to retire and now your multi-tenant app needs a mobile counterpart. You go ahead and implement JWT for authentication, but now InfoSec is knocking at your door because JWT is stateless.

That means once a token is issued, it stays valid until it expires or your signing key changes.
Now you have to maintain a revocation list and rotate your keys. Suddenly JWT is a lot less fun, but we can fix that.

As an example we'll use a Web API project with MVC parts that use cookies. The same project also functions as our authorization server. This post covers creating your own authorization server.
How you use clients (web, mobile) and tenants (Contoso, Northwind Traders) is up to you.

The Problem

JWT is stateless, which is a nightmare if a token leaks. The conventional remedies, revocation lists and key rotation, are complex, hard to implement and painful to manage.

The Solution

Store the JWT in a Redis cache and give the client only a pointer (the Redis key) to it. On top of that, check every request for signs of pointer hijacking, refresh the access token behind the scenes, and add proper user session management.
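The core idea can be shown without Redis or ASP.NET. The sketch below is a hypothetical PointerStoreSketch class in which a Dictionary stands in for the Redis cache and absolute expiry times stand in for key TTLs; it illustrates the indirection only and is not the implementation that follows. The client receives a random pointer, the JWT stays on the server, and revoking a session is just deleting the entry.

```csharp
using System;
using System.Collections.Generic;
using System.Security.Cryptography;

// Hypothetical in-memory stand-in for Redis: pointer -> (JWT, absolute expiry).
public class PointerStoreSketch
{
    private class Entry
    {
        public string Jwt;
        public DateTime ExpiresUtc;
    }

    private readonly Dictionary<string, Entry> entries = new Dictionary<string, Entry>();

    // Keeps the JWT server-side and hands back an opaque random pointer.
    public string Add(string jwt, TimeSpan lifetime, DateTime nowUtc)
    {
        var bytes = new byte[32];
        using (var rng = RandomNumberGenerator.Create())
        {
            rng.GetBytes(bytes);
        }
        string pointer = Convert.ToBase64String(bytes);
        entries[pointer] = new Entry { Jwt = jwt, ExpiresUtc = nowUtc + lifetime };
        return pointer;
    }

    // Returns the JWT for a pointer, or null if it is unknown, revoked or expired.
    public string Resolve(string pointer, DateTime nowUtc)
    {
        Entry entry;
        if (!entries.TryGetValue(pointer, out entry))
        {
            return null;
        }
        if (nowUtc >= entry.ExpiresUtc)
        {
            entries.Remove(pointer); // behaves like a key whose TTL has run out
            return null;
        }
        return entry.Jwt;
    }

    // Revoking a session is a single delete; the client never held the JWT.
    public void Revoke(string pointer)
    {
        entries.Remove(pointer);
    }
}
```

Because validity is decided by the server-side entry, a leaked pointer stops working as soon as its entry is removed, whatever the JWT's own expiry says.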

The How

Config Setup

In your Web.config, add the following app settings; keeping these values in configuration makes them easier to manage.

    <appSettings>
      <add key="Environment" value="DEV" />
      <add key="AuthUrl" value="https://localhost:44332" />
      <add key="AuthTokenAudienceId" value="afbe601022632561825b5af293bd302f" />
      <add key="AuthTokenAudienceSecret" value="Q04B8fA7ML3DkcYu7SAiu-HQUOMXCerIfh1LmfWN2bo" />
      <add key="AuthTokenTimeSpanMinutes" value="60" />
      <add key="RefreshTokenTimeSpanMinutes" value="1440" />
      <add key="AuthTokenAllowInsecureHttp" value="false" />
      <add key="PointerTokenCookieName" value="ExamplePointerToken" />
      <add key="ValidateIPAddress" value="true" />
      <add key="ValidateUserAgent" value="true" />
      <add key="EnableSingleSession" value="true" />
    </appSettings>
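The helpers below re-read these strings with Convert.ToBoolean and Convert.ToDouble on every call. One option, sketched here with a hypothetical AuthSettingsSketch class, is to parse them once at startup so a missing or mistyped value fails immediately instead of on the first request. In the MVC app, ConfigurationManager.AppSettings.Get could be passed as the lookup, since NameValueCollection.Get(string) returns null for a missing key.

```csharp
using System;
using System.Globalization;

// Hypothetical typed view over the appSettings keys listed above.
public class AuthSettingsSketch
{
    public TimeSpan AuthTokenLifetime { get; private set; }
    public TimeSpan RefreshTokenLifetime { get; private set; }
    public bool AllowInsecureHttp { get; private set; }
    public bool ValidateIPAddress { get; private set; }
    public bool ValidateUserAgent { get; private set; }
    public bool EnableSingleSession { get; private set; }
    public string PointerTokenCookieName { get; private set; }

    // 'lookup' returns the raw string for a key, or null when the key is missing.
    public static AuthSettingsSketch Load(Func<string, string> lookup)
    {
        Func<string, string> require = key =>
        {
            string value = lookup(key);
            if (string.IsNullOrWhiteSpace(value))
            {
                throw new InvalidOperationException("Missing app setting: " + key);
            }
            return value;
        };

        return new AuthSettingsSketch
        {
            AuthTokenLifetime = TimeSpan.FromMinutes(
                double.Parse(require("AuthTokenTimeSpanMinutes"), CultureInfo.InvariantCulture)),
            RefreshTokenLifetime = TimeSpan.FromMinutes(
                double.Parse(require("RefreshTokenTimeSpanMinutes"), CultureInfo.InvariantCulture)),
            AllowInsecureHttp = bool.Parse(require("AuthTokenAllowInsecureHttp")),
            ValidateIPAddress = bool.Parse(require("ValidateIPAddress")),
            ValidateUserAgent = bool.Parse(require("ValidateUserAgent")),
            EnableSingleSession = bool.Parse(require("EnableSingleSession")),
            PointerTokenCookieName = require("PointerTokenCookieName")
        };
    }
}
```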


Stack Exchange Redis Helper

This helper will come in handy with Azure Redis Cache. Everything cached in this post is a JSON string, so values are stored as plain strings. That also avoids BinaryFormatter, which is obsolete in current .NET and unsafe for deserializing data you don't fully control.

    using System;
    using System.Threading.Tasks;
    using Newtonsoft.Json;
    using StackExchange.Redis;

    public static class StackExchangeRedisHelper
    {
        // Reads a JSON value and deserializes it; returns default(T) if the key is missing.
        public static async Task<T> Get<T>(this IDatabase cache, string key)
        {
            string json = await cache.GetAsync(key);
            return json == null ? default(T) : JsonConvert.DeserializeObject<T>(json);
        }

        // Reads a string value; returns null if the key is missing.
        public static async Task<string> GetAsync(this IDatabase cache, string key)
        {
            RedisValue value = await cache.StringGetAsync(key);
            return value.HasValue ? (string)value : null;
        }

        // Writes a string value with a default 30-minute expiry;
        // callers may override it afterwards with KeyExpire.
        public static void Set(this IDatabase cache, string key, string value)
        {
            cache.StringSet(key, value, TimeSpan.FromMinutes(30));
        }
    }

General Helper

The method names are self-explanatory.

using System;
using System.Security.Cryptography;
using System.Text;
using System.Web;

public static class GeneralHelper
{
    // SHA-256 of the UTF-8 bytes of the input, Base64-encoded.
    public static string GetHash(string input)
    {
        using (HashAlgorithm hashAlgorithm = SHA256.Create())
        {
            byte[] byteValue = Encoding.UTF8.GetBytes(input);
            byte[] byteHash = hashAlgorithm.ComputeHash(byteValue);
            return Convert.ToBase64String(byteHash);
        }
    }

    private static string GetRandomCryptoBase64String()
    {
        using (RandomNumberGenerator rng = new RNGCryptoServiceProvider())
        {
            byte[] tokenData = new byte[4];
            rng.GetBytes(tokenData);
            return Convert.ToBase64String(tokenData);
        }
    }

    private static short GetRandomCryptoSubstringLengthNumber()
    {
        using (RandomNumberGenerator rng = new RNGCryptoServiceProvider())
        {
            byte[] tokenData = new byte[32];
            rng.GetBytes(tokenData);
            var random = BitConverter.ToInt16(tokenData, 0);
            return random;
        }
    }

    private static string GetRandomCryptoString()
    {
        using (RandomNumberGenerator rng = new RNGCryptoServiceProvider())
        {
            byte[] tokenData = new byte[4096];
            rng.GetBytes(tokenData);

            var rand = BitConverter.ToInt64(tokenData, 0);
            var random = BitConverter.ToInt16(tokenData, 0);

            const Decimal OldRange = (Decimal)int.MaxValue - (Decimal)int.MinValue;
            var unixEpoch = new DateTime(1970, 1, 1);
            var y2k = new DateTime(2000,1, 1).AddDays(random);
            var dateTimeNow = DateTime.Now;

            var min = (Int64)dateTimeNow.Subtract(unixEpoch).TotalMilliseconds;
            var max = (Int64)dateTimeNow.Subtract(y2k).TotalMilliseconds;

            Decimal NewRange = max - min;
            Decimal NewValue = ((Decimal)rand - (Decimal)int.MinValue) / OldRange * NewRange + (Decimal)min;
            var val = NewValue.ToString().Replace("-", "").Replace(".", "");
            return val;

        }
    }

    public static string GenerateJWTRedisKey()
    {

        var str = "";
        for (int c = 0; c < 10; c++)
        {
            str = str + GetRandomCryptoString() + GetRandomCryptoBase64String();
        }

        // Pick a length in [128, 256] directly instead of looping until a random
        // 16-bit value happens to land in that range (about 500 tries on average).
        int substringLen = 128 + (ushort)GetRandomCryptoSubstringLengthNumber() % 129;
        return str.Substring(0, substringLen);
    }

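    public static string GetForwardedOrRemoteIPAddress()
    {
        var request = HttpContext.Current.Request;
        string ipAddress = request.ServerVariables["HTTP_X_FORWARDED_FOR"];
        if (!string.IsNullOrEmpty(ipAddress))
        {
            // The header may hold a chain ("client, proxy1, ..."). Take the right-most entry,
            // added by the proxy nearest to us; entries to its left can be forged by the
            // client. Only trust this header when the app sits behind a proxy you control.
            string[] entries = ipAddress.Split(',');
            ipAddress = entries[entries.Length - 1].Trim();

            // Strip a port from "1.2.3.4:5678". IPv6 addresses contain two or more
            // colons, so only a single colon indicates a port.
            int colon = ipAddress.IndexOf(':');
            if (colon >= 0 && colon == ipAddress.LastIndexOf(':'))
            {
                ipAddress = ipAddress.Substring(0, colon);
            }
        }
        else
        {
            ipAddress = request.ServerVariables["REMOTE_ADDR"];
            if (ipAddress == "::1")
            {
                ipAddress = "127.0.0.1";
            }
        }
        return ipAddress;
    }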

    // d must be in UTC (LastAccessed is stored with DateTime.UtcNow).
    public static string GetPrettyDate(DateTime d)
    {
        var now = DateTime.UtcNow;            
        TimeSpan s = now.Subtract(d);            
        int dayDiff = (int)s.TotalDays;   
        int secDiff = (int)s.TotalSeconds;
        if (dayDiff < 0 || dayDiff >= 367)
        {
            return null;
        }
        if (dayDiff == 0)
        {              
            if (secDiff < 60){return "just now";}     
            if (secDiff < 120){return "1 minute ago";}   
            if (secDiff < 3600)
            {
                return string.Format("{0} minutes ago",
                Math.Floor((double)secDiff / 60));
            }    
            if (secDiff < 7200){return "1 hour ago";}                
            if (secDiff < 86400)
            {
                return string.Format("{0} hours ago",
                Math.Floor((double)secDiff / 3600));
            }
        }      
        if (dayDiff == 1){return "yesterday";}
        if (dayDiff < 7){return string.Format("{0} days ago",dayDiff);}
        if (dayDiff < 31)
        {
            var weeks = Math.Ceiling((double)dayDiff / 7);
            return weeks == 1 ? "1 week ago" : string.Format("{0} weeks ago", weeks);
        }
        if (dayDiff < 365)
        {
            return string.Format("{0} months ago",
            Math.Ceiling((double)dayDiff / 30));
        }
        return null;
    }
}
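The unpredictability of GenerateJWTRedisKey comes from its random bytes; the timestamps, decimal arithmetic and random substring length add complexity rather than strength. A simpler alternative, offered as a sketch and not as the author's method, takes 32 bytes (256 bits) straight from the OS cryptographic RNG and Base64Url-encodes them so the pointer can go into a cookie without escaping. PointerGeneratorSketch is a hypothetical name.

```csharp
using System;
using System.Security.Cryptography;

// Hypothetical replacement for GenerateJWTRedisKey.
public static class PointerGeneratorSketch
{
    // 32 random bytes from the CSPRNG, Base64Url-encoded without padding.
    public static string NewPointer()
    {
        byte[] bytes = new byte[32];
        using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
        {
            rng.GetBytes(bytes);
        }
        return Convert.ToBase64String(bytes)
            .TrimEnd('=')
            .Replace('+', '-')
            .Replace('/', '_');
    }
}
```

The result is always 43 characters: 32 bytes encode to 44 Base64 characters including one '=' of padding, which is trimmed.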

Token Helper

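This helper deals with the pointer, the JWT and the other features. The members below live in a static TokenHelper class; Connection, used throughout, is assumed to be a shared StackExchange.Redis ConnectionMultiplexer (its setup isn't shown here).

The name of the cookie that holds the pointer value.

    private static string pointerTokenCookieName =
        ConfigurationManager.AppSettings["PointerTokenCookieName"];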

Whether the user's IP address should be checked against the one stored with the token in Redis.

    public static bool ValidateIPAddress()
    {
        return Convert.ToBoolean(ConfigurationManager.AppSettings["ValidateIPAddress"]);
    }
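The IP check is only as reliable as the address it compares. X-Forwarded-For may carry a comma-separated chain and a port, and any entries the client sent itself are easy to forge. A minimal sketch of a stricter parser follows. ClientIpSketch is hypothetical, and it assumes exactly one reverse proxy under your control that appends the address it saw as the right-most entry; with N trusted proxies you would take the N-th entry from the right instead.

```csharp
using System;
using System.Net;

// Hypothetical client-address resolver.
public static class ClientIpSketch
{
    // ASSUMPTION: exactly one trusted reverse proxy, which appends the address it saw
    // as the right-most X-Forwarded-For entry. Entries further left can be forged.
    // Falls back to REMOTE_ADDR when the header is absent.
    public static string Resolve(string forwardedFor, string remoteAddr)
    {
        string candidate = remoteAddr;
        if (!string.IsNullOrWhiteSpace(forwardedFor))
        {
            string[] entries = forwardedFor.Split(',');
            candidate = entries[entries.Length - 1].Trim();
        }

        IPAddress address = ParseWithoutPort(candidate);
        if (address == null)
        {
            return null;
        }
        if (address.IsIPv4MappedToIPv6)
        {
            address = address.MapToIPv4();
        }
        // Normalize IPv6 loopback the same way GetForwardedOrRemoteIPAddress does.
        return IPAddress.IPv6Loopback.Equals(address) ? "127.0.0.1" : address.ToString();
    }

    // Accepts "1.2.3.4", "1.2.3.4:5678", "2001:db8::1" and "[2001:db8::1]:443".
    private static IPAddress ParseWithoutPort(string value)
    {
        if (string.IsNullOrEmpty(value))
        {
            return null;
        }
        if (value.StartsWith("["))
        {
            int close = value.IndexOf(']');
            value = close > 0 ? value.Substring(1, close - 1) : value;
        }
        else if (value.IndexOf(':') >= 0 && value.IndexOf(':') == value.LastIndexOf(':'))
        {
            // Exactly one colon: IPv4 with a port (IPv6 always has two or more colons).
            value = value.Substring(0, value.IndexOf(':'));
        }
        IPAddress address;
        return IPAddress.TryParse(value, out address) ? address : null;
    }
}
```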

Whether the user's user agent should be checked against the one stored with the token in Redis.

    public static bool ValidateUserAgent()
    {
        return Convert.ToBoolean(ConfigurationManager.AppSettings["ValidateUserAgent"]);
    }

Whether a user may be signed in on only one device at a time. When enabled, a new login ends every other session for that user.

    public static bool EnableSingleSession()
    {
        return Convert.ToBoolean(ConfigurationManager.AppSettings["EnableSingleSession"]);
    }

Retrieves the JWT from the authorization server.

    // Returns the raw token response JSON, or null if the authorization server rejects the request.
    public static async Task<string> GetTokenFromAuthService(string email, string password, string tenant)
    {
        string ipAddress = GeneralHelper.GetForwardedOrRemoteIPAddress();
        var formContent = new FormUrlEncodedContent(new[]
        {
            new KeyValuePair<string, string>("grant_type", "password"),
            new KeyValuePair<string, string>("username", email),
            new KeyValuePair<string, string>("password", password),
            new KeyValuePair<string, string>("client_id", tenant),
            new KeyValuePair<string, string>("user_agent", HttpContext.Current.Request.ServerVariables["HTTP_USER_AGENT"]),
            new KeyValuePair<string, string>("ip_address", ipAddress)
        });
        using (var httpClient = new HttpClient())
        {
            var result = await httpClient.PostAsync(ConfigurationManager.AppSettings["AuthUrl"] + "/oauth/token", formContent);
            if (result.StatusCode != System.Net.HttpStatusCode.OK)
            {
                return null;
            }
            return await result.Content.ReadAsStringAsync();
        }
    }

Stores the token response in Redis under a new pointer and returns the pointer.

public async static Task<string> AddTokenToRedis(string authJson, bool useRefreshTokenTimeSpan)
    {
        var authToken = JsonConvert.DeserializeObject<AuthServiceTokenResponse>(authJson);
        IDatabase cache = Connection.GetDatabase();
        var pointer = GeneralHelper.GenerateJWTRedisKey();
        authToken.pointer = pointer;
        authToken.LastAccessed = DateTime.UtcNow;
        authToken.useRefreshTokenTimeSpan = useRefreshTokenTimeSpan;
        cache.Set(pointer, JsonConvert.SerializeObject(authToken));
        var expires = TimeSpan.FromSeconds(authToken.expires_in);
        if (useRefreshTokenTimeSpan)
        { 
            expires = TimeSpan.FromMinutes(Convert.ToDouble(ConfigurationManager.AppSettings["RefreshTokenTimeSpanMinutes"]));

        }
        cache.KeyExpire(pointer, expires);
        await AddUserPointerToRedis(authToken.userId, pointer, expires);
        return pointer;
    }
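The same choice, refresh-token lifetime for "remember me" sessions and access-token lifetime otherwise, reappears in GetTokenFromRedis, RefreshToken and GetPointerCookie. One option is to compute it in a single place so the Redis TTL and the cookie expiry cannot drift apart; PointerLifetimeSketch.For below is a hypothetical helper illustrating that.

```csharp
using System;

// Hypothetical single source of truth for pointer lifetimes.
public static class PointerLifetimeSketch
{
    // "Remember me" sessions live as long as the refresh token; otherwise the
    // pointer lives exactly as long as the access token (expires_in, in seconds).
    public static TimeSpan For(bool useRefreshTokenTimeSpan, int expiresInSeconds, double refreshTokenTimeSpanMinutes)
    {
        return useRefreshTokenTimeSpan
            ? TimeSpan.FromMinutes(refreshTokenTimeSpanMinutes)
            : TimeSpan.FromSeconds(expiresInSeconds);
    }
}
```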

Validates the JWT and sets the current user.

    public static async Task<bool> SetTokenForUser(string pointer)
    {
        var authToken = await GetTokenFromRedis(pointer);
        if (authToken == null)
        {
            // The pointer has expired or been revoked.
            return false;
        }
        byte[] audienceSecret = TextEncodings.Base64Url.Decode(ConfigurationManager.AppSettings["AuthTokenAudienceSecret"]);
        var tokenHandler = new JwtSecurityTokenHandler();
        var validationParameters = new TokenValidationParameters()
        {
            ValidAudience = ConfigurationManager.AppSettings["AuthTokenAudienceId"],
            ValidIssuer = ConfigurationManager.AppSettings["AuthUrl"],
            IssuerSigningToken = new BinarySecretSecurityToken(audienceSecret)
        };

        try
        {
            // Checks signature, issuer, audience and lifetime; throws if any check fails.
            SecurityToken securityToken;
            var principal = tokenHandler.ValidateToken(authToken.access_token, validationParameters, out securityToken);
            HttpContext.Current.User = principal;
            return true;
        }
        catch (Exception)
        {
            return false;
        }
    }
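ValidateToken does the heavy lifting here. To show what the signature part of that check amounts to when the signing key is a shared secret (as BinarySecretSecurityToken suggests), here is a minimal HMAC-SHA256 (HS256) sketch using only the base class library. It assumes the authorization server signs with HS256, Hs256Sketch is a hypothetical name, and it deliberately skips the expiry, issuer and audience checks, so it illustrates the mechanism and does not replace ValidateToken.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical, signature-only illustration of HS256 JWT verification.
public static class Hs256Sketch
{
    // HMAC-SHA256 over the signing input "<header>.<payload>", Base64Url-encoded.
    public static string ComputeSignature(string signingInput, byte[] secret)
    {
        using (var hmac = new HMACSHA256(secret))
        {
            return Base64UrlEncode(hmac.ComputeHash(Encoding.ASCII.GetBytes(signingInput)));
        }
    }

    // Checks the signature only: exp, nbf, iss and aud are NOT validated here.
    public static bool HasValidSignature(string jwt, byte[] secret)
    {
        string[] parts = jwt.Split('.');
        if (parts.Length != 3)
        {
            return false;
        }
        byte[] expected = Encoding.ASCII.GetBytes(ComputeSignature(parts[0] + "." + parts[1], secret));
        byte[] actual = Encoding.ASCII.GetBytes(parts[2]);
        return FixedTimeEquals(expected, actual);
    }

    public static string Base64UrlEncode(byte[] data)
    {
        return Convert.ToBase64String(data).TrimEnd('=').Replace('+', '-').Replace('/', '_');
    }

    // Compares every byte so the running time does not reveal where a mismatch occurs.
    private static bool FixedTimeEquals(byte[] a, byte[] b)
    {
        if (a.Length != b.Length)
        {
            return false;
        }
        int diff = 0;
        for (int i = 0; i < a.Length; i++)
        {
            diff |= a[i] ^ b[i];
        }
        return diff == 0;
    }
}
```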

Retrieves the token from Redis using the pointer, updating LastAccessed and sliding the key's expiration on every read.

public async static Task<AuthServiceTokenResponse> GetTokenFromRedis(string pointer)
    {
        IDatabase cache = Connection.GetDatabase();
        var authJson = (string)await cache.GetAsync(pointer);
        if (string.IsNullOrEmpty(authJson))
        {
            return null;
        }
        var authToken = JsonConvert.DeserializeObject<AuthServiceTokenResponse>(authJson);
        authToken.LastAccessed = DateTime.UtcNow;
        cache.Set(pointer, JsonConvert.SerializeObject(authToken));

        if (authToken.useRefreshTokenTimeSpan)
        {
            cache.KeyExpire(pointer, TimeSpan.FromMinutes(Convert.ToDouble(ConfigurationManager.AppSettings["RefreshTokenTimeSpanMinutes"])));
        }
        else
        {
            cache.KeyExpire(pointer, TimeSpan.FromSeconds(authToken.expires_in));
        }
        return authToken;
    }

Removes the token from Redis using the pointer, which revokes the session immediately.

    public static void ClearTokenFromRedis(string pointer)
    {
        IDatabase cache = Connection.GetDatabase();
        cache.KeyDelete(pointer);
    }

Checks whether the access token is still fresh. A false return means it should be refreshed, or that the pointer no longer exists.

    public static async Task<bool> CheckTokenValidTo(string pointer)
    {
        var authToken = await GetTokenFromRedis(pointer);
        if (authToken == null)
        {
            return false;
        }
        var handler = new JwtSecurityTokenHandler();
        var jwt = handler.ReadToken(authToken.access_token) as JwtSecurityToken;
        if (jwt == null)
        {
            return false;
        }
        // Fresh while more than (AuthTokenTimeSpanMinutes - 15) minutes remain.
        var minutesLeft = (jwt.ValidTo - DateTime.UtcNow).TotalMinutes;
        var threshold = Convert.ToInt32(ConfigurationManager.AppSettings["AuthTokenTimeSpanMinutes"]) - 15;
        return minutesLeft > threshold;
    }
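The threshold in CheckTokenValidTo is measured against the token's full lifetime, not its expiry: with AuthTokenTimeSpanMinutes set to 60, the token counts as fresh only while more than 60 - 15 = 45 minutes remain, so it is refreshed once it is 15 minutes old. A pure-function sketch of the same rule, using a hypothetical RefreshPolicySketch with the 15 minutes as a parameter, makes that easy to check in isolation.

```csharp
using System;

// Hypothetical restatement of the CheckTokenValidTo rule.
public static class RefreshPolicySketch
{
    // Fresh while more than (lifetimeMinutes - refreshAfterMinutes) minutes remain,
    // i.e. the token is refreshed once it is older than refreshAfterMinutes.
    public static bool IsFresh(DateTime validToUtc, DateTime nowUtc, int lifetimeMinutes, int refreshAfterMinutes = 15)
    {
        double minutesLeft = (validToUtc - nowUtc).TotalMinutes;
        return minutesLeft > lifetimeMinutes - refreshAfterMinutes;
    }
}
```

If you would rather refresh only during the last 15 minutes of a token's life, compare minutesLeft against 15 instead.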

Checks whether the pointer may have been hijacked, by comparing the request's IP address and user agent with the ones stored with the token.

public static async Task<bool> CheckTokenAuthenticity(string pointer)
    {
        var authToken = await GetTokenFromRedis(pointer);
        if (authToken == null)
        {
            return false;
        }
        if (ValidateIPAddress())
        {
            if (authToken.ip_address != GeneralHelper.GetForwardedOrRemoteIPAddress())
            {
                return false;
            }
        }
        if (ValidateUserAgent())
        {
            if (authToken.user_agent != HttpContext.Current.Request.ServerVariables["HTTP_USER_AGENT"])
            {
                return false;
            }
        }
        return true;
    }

Refreshes the JWT for a new access token.

public async static Task<bool> RefreshToken(string pointer)
    {
        var token = await GetTokenFromRedis(pointer);
        if (token == null)
        {
            return false;
        }
        var formContent = new FormUrlEncodedContent(new[]
        {
            new KeyValuePair<string, string>("grant_type", "refresh_token"),
            new KeyValuePair<string, string>("refresh_token", token.refresh_token),
            new KeyValuePair<string, string>("client_id", token.client_id),
            new KeyValuePair<string, string>("user_agent", HttpContext.Current.Request.ServerVariables["HTTP_USER_AGENT"]),
            new KeyValuePair<string, string>("ip_address", GeneralHelper.GetForwardedOrRemoteIPAddress())
        });
        var response = "";

        using (var httpClient = new HttpClient())
        {
            var result = await httpClient.PostAsync(ConfigurationManager.AppSettings["AuthUrl"] + "/oauth/token", formContent);
            if (result.StatusCode != System.Net.HttpStatusCode.OK)
            {
                return false;
            }
            response = await result.Content.ReadAsStringAsync();

        }
        var authToken = JsonConvert.DeserializeObject<AuthServiceTokenResponse>(response);
        IDatabase cache = Connection.GetDatabase();
        authToken.pointer = pointer;
        authToken.LastAccessed = token.LastAccessed;
        authToken.useRefreshTokenTimeSpan = token.useRefreshTokenTimeSpan;
        cache.Set(pointer, JsonConvert.SerializeObject(authToken));

        if (token.useRefreshTokenTimeSpan)
        {
            cache.KeyExpire(pointer, TimeSpan.FromMinutes(Convert.ToDouble(ConfigurationManager.AppSettings["RefreshTokenTimeSpanMinutes"])));
        }
        else
        {
            cache.KeyExpire(pointer, TimeSpan.FromSeconds(authToken.expires_in));
        }
        return true;
    }

Creates the cookie with the pointer value.

public static async Task<HttpCookie> GetPointerCookie(string pointer)
    {
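        var authToken = await GetTokenFromRedis(pointer);
        if (authToken == null)
        {
            // The pointer has expired or been revoked; callers should not set a cookie.
            return null;
        }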
        HttpCookie cookie = new HttpCookie(pointerTokenCookieName);
        cookie.Value = pointer;
        if (authToken.useRefreshTokenTimeSpan)
        {
            cookie.Expires = DateTime.UtcNow.AddMinutes(Convert.ToDouble(ConfigurationManager.AppSettings["RefreshTokenTimeSpanMinutes"]));
        }
        else
        {
            cookie.Expires = DateTime.UtcNow.AddSeconds(Convert.ToDouble(authToken.expires_in));
        }
        cookie.HttpOnly = true;
        cookie.Secure = true;
        return cookie;
    }

Expires the cookie.

    public static void ClearCookies()
    {
        // Overwrite the browser's cookie with an expired one. Response.Cookies[name] creates
        // the entry if needed, whereas Request.Cookies[name] is null when no cookie was sent.
        HttpContext.Current.Response.Cookies[pointerTokenCookieName].Expires = DateTime.UtcNow.AddDays(-1);
        HttpContext.Current.User = null;
    }

Retrieves all the pointers for a user from Redis.

public static async Task<RedisUserPointers> GetUserPointersFromRedis(string userId)
    {
        IDatabase cache = Connection.GetDatabase();
        var userPointersJson = (string) await cache.GetAsync(userId);
        if (string.IsNullOrEmpty(userPointersJson))
        {
            return new RedisUserPointers() { Pointers = new List<string>() };
        }
        return JsonConvert.DeserializeObject<RedisUserPointers>(userPointersJson);
    }

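Adds a pointer to the user's list of active pointers in Redis. When single-session mode is on, every other session for that user is revoked first.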

public static async Task AddUserPointerToRedis(string userId, string pointer, TimeSpan expires)
    {
        IDatabase cache = Connection.GetDatabase();
        var userPointers = await GetUserPointersFromRedis(userId);
        userPointers.UserId = userId;
        if (!userPointers.Pointers.Contains(pointer))
        {
            if (EnableSingleSession())
            {
                // Revoke every other session for this user.
                foreach (var t in userPointers.Pointers)
                {
                    ClearTokenFromRedis(t);
                }
                userPointers.Pointers = new List<string>();
            }
            userPointers.Pointers.Add(pointer);
            cache.Set(userId, JsonConvert.SerializeObject(userPointers));
            cache.KeyExpire(userId, expires);
        }
        await ClearExpiredUserTokensFromRedis(userId);
    }
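Stripped of Redis and JSON, the single-session rule is small. The sketch below uses a hypothetical UserSessionIndexSketch with a Dictionary in place of the userId key; it returns the pointers to revoke instead of deleting them itself, which keeps the rule testable on its own.

```csharp
using System.Collections.Generic;

// Hypothetical in-memory model of the per-user pointer list (RedisUserPointers).
public class UserSessionIndexSketch
{
    private readonly Dictionary<string, List<string>> pointersByUser =
        new Dictionary<string, List<string>>();
    private readonly bool singleSession;

    public UserSessionIndexSketch(bool singleSession)
    {
        this.singleSession = singleSession;
    }

    // Registers a pointer and returns the pointers that must now be revoked
    // (deleted from the cache). Re-adding a known pointer changes nothing.
    public List<string> Add(string userId, string pointer)
    {
        List<string> pointers;
        if (!pointersByUser.TryGetValue(userId, out pointers))
        {
            pointers = new List<string>();
            pointersByUser[userId] = pointers;
        }

        var revoked = new List<string>();
        if (pointers.Contains(pointer))
        {
            return revoked;
        }
        if (singleSession)
        {
            revoked.AddRange(pointers);
            pointers.Clear();
        }
        pointers.Add(pointer);
        return revoked;
    }

    // Snapshot of the user's currently registered pointers.
    public List<string> Active(string userId)
    {
        List<string> pointers;
        return pointersByUser.TryGetValue(userId, out pointers)
            ? new List<string>(pointers)
            : new List<string>();
    }
}
```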

Removes pointers whose tokens have expired from the user's list, then resets the list's expiry.

    public static async Task ClearExpiredUserTokensFromRedis(string userId)
    {
        IDatabase cache = Connection.GetDatabase();
        var userPointersJson = (string)await cache.GetAsync(userId);
        var userPointers = new RedisUserPointers() { Pointers = new List<string>(), UserId=userId };
        var useRefreshTokenTimeSpan = false;
        if (!string.IsNullOrEmpty(userPointersJson))
        {
            userPointers = JsonConvert.DeserializeObject<RedisUserPointers>(userPointersJson);
            var pointersToRemove = new List<string>();
            foreach (var t in userPointers.Pointers)
            {
                var oldToken = (string)await cache.GetAsync(t);
                if (string.IsNullOrEmpty(oldToken))
                {
                    // The pointer's key has expired or been cleared.
                    pointersToRemove.Add(t);
                }
                else
                {
                    var authToken = JsonConvert.DeserializeObject<AuthServiceTokenResponse>(oldToken);
                    if (authToken.useRefreshTokenTimeSpan)
                    {
                        useRefreshTokenTimeSpan = true;
                    }
                }
            }
            foreach (var t in pointersToRemove)
            {
                userPointers.Pointers.Remove(t);
            }
        }
        var expires = TimeSpan.FromMinutes(Convert.ToDouble(ConfigurationManager.AppSettings["AuthTokenTimeSpanMinutes"]));            
        if (useRefreshTokenTimeSpan)
        {
            expires = TimeSpan.FromMinutes(Convert.ToDouble(ConfigurationManager.AppSettings["RefreshTokenTimeSpanMinutes"]));

        }

        cache.Set(userId, JsonConvert.SerializeObject(userPointers));
        cache.KeyExpire(userId, expires);
    }

Models

public class AuthServiceTokenResponse
{
    public string client_id { get; set; }
    public string pointer { get; set; }
    public string access_token { get; set; }
    public string token_type { get; set; }
    public int expires_in { get; set; }
    public string refresh_token { get; set; }
    public string userName { get; set; }
    public string ip_address { get; set; }
    public string user_agent { get; set; }
    public string userId { get; set; }
    public DateTime LastAccessed { get; set; }
    public bool useRefreshTokenTimeSpan { get; set; }
}

public class RedisUserPointers
{
    public string UserId { get; set; }
    public List<string> Pointers { get; set; }
}

Putting It Together

Your Login action in the Account controller should look like this:

    [HttpPost]
    [AllowAnonymous]
    [ValidateAntiForgeryToken]
    public async Task<ActionResult> Login(LoginViewModel model, string returnUrl)
    {
        if (!ModelState.IsValid)
        {
            return View(model);
        }
        var token = await TokenHelper.GetTokenFromAuthService(model.Email, model.Password, model.Tenant);
        if (token == null)
        {
            // The authorization server rejected the credentials.
            ModelState.AddModelError("", "Invalid login attempt.");
            return View(model);
        }
        var pointer = await TokenHelper.AddTokenToRedis(token, model.RememberMe);
        await TokenHelper.SetTokenForUser(pointer);
        var pointerCookie = await TokenHelper.GetPointerCookie(pointer);
        Response.Cookies.Add(pointerCookie);
        return RedirectToLocal(returnUrl);
    }

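Your LogOff action in the Account controller should look like this:

    [AllowAnonymous]
    [HttpPost]
    public ActionResult LogOff()
    {
        var pointerCookie = Request.Cookies[ConfigurationManager.AppSettings["PointerTokenCookieName"]];
        if (pointerCookie != null)
        {
            // Deleting the Redis entry revokes the session server-side.
            TokenHelper.ClearTokenFromRedis(pointerCookie.Value);
        }
        TokenHelper.ClearCookies();
        Session.Abandon();
        // The redirect target is an example; send the user wherever suits your app.
        return RedirectToAction("Index", "Home");
    }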

In your Global.asax.cs file, add a constructor if there isn't one. In the constructor add the following code.

        var wrapper = new EventHandlerTaskAsyncHelper(CheckAuth);
        this.AddOnAcquireRequestStateAsync(wrapper.BeginEventHandler, wrapper.EndEventHandler);

Then add the CheckAuth method.

    private async Task CheckAuth(object sender, EventArgs e)
    {
        var cookieName = ConfigurationManager.AppSettings["PointerTokenCookieName"];

        // Ignore Visual Studio Browser Link requests.
        if (Request.RawUrl.Contains("__browserLink"))
        {
            return;
        }
        // Leave the cookie alone on LogOff: touching Response.Cookies[name] here adds an empty
        // cookie that ASP.NET also exposes through Request.Cookies, hiding the pointer from
        // LogOff. ClearCookies expires the cookie there instead.
        if (Request.RawUrl.Contains("LogOff"))
        {
            return;
        }
        var pointerCookie = Request.Cookies[cookieName];
        if (pointerCookie == null)
        {
            return;
        }
        // Ignore the pointer if the IP address or user agent no longer match.
        if (!await TokenHelper.CheckTokenAuthenticity(pointerCookie.Value))
        {
            return;
        }
        // Swap in a new access token behind the scenes once the current one is no longer
        // fresh, and re-issue the cookie only if the refresh succeeded.
        if (!await TokenHelper.CheckTokenValidTo(pointerCookie.Value)
            && await TokenHelper.RefreshToken(pointerCookie.Value))
        {
            Response.Cookies.Add(await TokenHelper.GetPointerCookie(pointerCookie.Value));
        }
        await TokenHelper.SetTokenForUser(pointerCookie.Value);
    }

Real World Use

Active User Sessions

The client can't tell a JWT from a pointer, since both are opaque bearer strings, so switching to pointers needs no client changes. And since the JWT is only used server side, it never reaches the client and cannot leak from there; a leaked pointer can be revoked by deleting its Redis key.

By fixing the problems that come with JWT, we can easily add user session management: monitor active user sessions, or give users an overview of their active sessions with the ability to end any of them.
An example project is on GitHub.
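For the session overview, the cached entries already carry LastAccessed, user_agent and ip_address. A minimal sketch follows, assuming you map each cached AuthServiceTokenResponse to a hypothetical SessionInfoSketch. It orders sessions by last use and lets a user pick one to end without the raw pointers ever being sent to the browser, since each pointer is a bearer credential; the page shows a SHA-256 of the pointer instead, and ending the session would then be a call to ClearTokenFromRedis with the pointer FindPointer returns.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using System.Text;

// Hypothetical view of one session, built from a cached AuthServiceTokenResponse.
public class SessionInfoSketch
{
    public string Pointer { get; set; }
    public string UserAgent { get; set; }
    public string IpAddress { get; set; }
    public DateTime LastAccessedUtc { get; set; }
}

// Hypothetical helpers for an "active sessions" page.
public static class SessionOverviewSketch
{
    // Public session id: a SHA-256 of the pointer, safe to render in HTML.
    public static string PublicId(string pointer)
    {
        using (var sha = SHA256.Create())
        {
            byte[] hash = sha.ComputeHash(Encoding.UTF8.GetBytes(pointer));
            return Convert.ToBase64String(hash);
        }
    }

    // Most recently used session first.
    public static List<SessionInfoSketch> Overview(IEnumerable<SessionInfoSketch> sessions)
    {
        return sessions.OrderByDescending(s => s.LastAccessedUtc).ToList();
    }

    // Maps the public id the user chose back to the pointer to revoke, or null.
    public static string FindPointer(IEnumerable<SessionInfoSketch> sessions, string publicId)
    {
        return sessions.Select(s => s.Pointer).FirstOrDefault(p => PublicId(p) == publicId);
    }
}
```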