feat: reject token if it's not valid

This commit is contained in:
2024-06-15 06:26:13 -04:00
parent dfd709ad1c
commit 5d528fba52
12 changed files with 368 additions and 16 deletions

View File

@@ -1,7 +1,7 @@
import { eq } from "drizzle-orm";
import { DateTime } from "luxon";
import { beforeEach, describe, expect, it } from "bun:test";
import { beforeEach, describe, expect, it, mock } from "bun:test";
import app from "~/index";
import { getTestDb } from "~/libs/test/getTestDb";
@@ -16,6 +16,9 @@ describe("requests the /token route", () => {
beforeEach(async () => {
await resetTestDb();
mock.module("src/libs/fcm/verifyFcmToken", () => ({
verifyFcmToken: () => true,
}));
});
it("should succeed", async () => {
@@ -191,4 +194,42 @@ describe("requests the /token route", () => {
expect(row).toBeUndefined();
});
it("token is invalid, should fail", async () => {
mock.module("src/libs/fcm/verifyFcmToken", () => ({
verifyFcmToken: () => false,
}));
const res = await app.request("/token", {
method: "POST",
headers: new Headers({
"Content-Type": "application/json",
}),
body: JSON.stringify({ token: "123", deviceId: "124", username: null }),
});
expect(res.json()).resolves.toEqual({ success: false });
expect(res.status).toBe(401);
});
it("token is invalid, should not insert new entry", async () => {
mock.module("src/libs/fcm/verifyFcmToken", () => ({
verifyFcmToken: () => false,
}));
await app.request("/token", {
method: "POST",
headers: new Headers({
"Content-Type": "application/json",
}),
body: JSON.stringify({ token: "123", deviceId: "124", username: null }),
});
const row = await db
.select()
.from(deviceTokensTable)
.where(eq(deviceTokensTable.deviceId, "124"))
.get();
expect(row).toBeUndefined();
});
});

View File

@@ -1,6 +1,11 @@
import { OpenAPIHono, createRoute, z } from "@hono/zod-openapi";
import { env } from "hono/adapter";
import mapKeys from "lodash.mapkeys";
import { Case, changeStringCase } from "~/libs/changeStringCase";
import type { AdminSdkCredentials } from "~/libs/fcm/getGoogleAuthToken";
import { verifyFcmToken } from "~/libs/fcm/verifyFcmToken";
import { readEnvVariable } from "~/libs/readEnvVariable";
import { saveToken } from "~/models/token";
import type { Env } from "~/types/env";
import {
@@ -51,8 +56,20 @@ app.openapi(route, async (c) => {
await c.req.json<typeof SaveTokenRequest._type>();
try {
const isValidToken = await verifyFcmToken(
token,
mapKeys(
readEnvVariable<AdminSdkCredentials>(c.env, "ADMIN_SDK_JSON"),
(_, key) => changeStringCase(key, Case.snake_case, Case.camelCase),
) as unknown as AdminSdkCredentials,
);
if (!isValidToken) {
return c.json(ErrorResponse, 401);
}
await saveToken(env(c, "workerd"), deviceId, token, username);
} catch (error) {
// when token already exists in the database
if (
error.code === "SQLITE_CONSTRAINT" &&
error.message.includes("device_tokens.token")

View File

@@ -0,0 +1,29 @@
import { describe, expect, it } from "bun:test";
import { Case, changeStringCase } from "./changeStringCase";
describe("changeStringCase", () => {
it("from camelCase to snake_case", () => {
expect(
changeStringCase("camelCase", Case.camelCase, Case.snake_case),
).toEqual("camel_case");
});
it("from snake_case to camelCase", () => {
expect(
changeStringCase("snake_case", Case.snake_case, Case.camelCase),
).toEqual("snakeCase");
});
it("from camelCase to camelCase", () => {
expect(
changeStringCase("camelCase", Case.camelCase, Case.camelCase),
).toEqual("camelCase");
});
it("from snake_case to snake_case", () => {
expect(
changeStringCase("snake_case", Case.snake_case, Case.snake_case),
).toEqual("snake_case");
});
});

View File

@@ -0,0 +1,34 @@
export enum Case {
camelCase,
snake_case,
}
export function changeStringCase<T>(
str: string,
currentCase: Case,
newCase: Case,
): string {
if (currentCase === newCase) {
return str;
}
const currentSeparator = currentCase === Case.snake_case ? "_" : /(?=[A-Z])/;
const words = str.split(currentSeparator);
return words
.map((word, index) => {
switch (newCase) {
case Case.camelCase:
if (index === 0) {
return word.toLowerCase();
}
return word[0].toUpperCase() + word.slice(1).toLowerCase();
case Case.snake_case:
return word.toLowerCase();
default:
throw new Error("Unknown case");
}
})
.join(newCase === Case.snake_case ? "_" : "");
}

View File

@@ -1,26 +1,171 @@
import { GoogleToken } from "gtoken";
import type { GetTokenOptions, TokenData as GoogleTokenData } from "gtoken";
import { SignJWT, importPKCS8 } from "jose";
import { DateTime } from "luxon";
import { lazy } from "../lazy";
export async function getGoogleAuthToken(adminSdkJson: AdminSdkCredentials) {
const { privateKey, clientEmail } = adminSdkJson;
const gToken = new GoogleToken({
key: privateKey,
email: clientEmail,
scope: ["https://www.googleapis.com/auth/firebase.messaging"],
});
const gToken = new GoogleToken(
{
key: privateKey,
email: clientEmail,
scope: ["https://www.googleapis.com/auth/firebase.messaging"],
},
adminSdkJson,
);
return gToken.getToken().then((token) => token.access_token);
}
const GOOGLE_TOKEN_URL = "https://www.googleapis.com/oauth2/v4/token";
class GoogleToken {
#inFlightRequest?: Promise<GoogleTokenData>;
#tokenData?: TokenData;
#scope = lazy(() => {
const scope = this.options.scope;
if (Array.isArray(scope)) {
return scope;
} else if (typeof scope === "string") {
return [scope];
} else {
return [];
}
});
constructor(
private options: TokenOptions,
credentialsJson: AdminSdkCredentials,
) {
this.options.key = this.options.key ?? credentialsJson.privateKey;
this.options.email = this.options.email ?? credentialsJson.clientEmail;
}
getToken({
forceRefresh = false,
...options
}: GetTokenOptions = {}): Promise<GoogleTokenData> {
return this.#getTokenAsync({ ...options, forceRefresh });
}
isTokenExpiring(): boolean {
const now = DateTime.now();
const eagerRefreshThresholdMillis =
this.options.eagerRefreshThresholdMillis ?? 0;
if (this.#tokenData) {
return (
now.plus({ milliseconds: eagerRefreshThresholdMillis }) >=
this.#tokenData.expiresAt
);
}
return true;
}
#getTokenAsync(options: GetTokenOptions): Promise<GoogleTokenData> {
const { forceRefresh } = options;
if (this.#inFlightRequest && !forceRefresh) {
return this.#inFlightRequest;
}
try {
this.#inFlightRequest = this.#getTokenAsyncInternal(options);
return this.#inFlightRequest!;
} finally {
this.#inFlightRequest = undefined;
}
}
#getTokenAsyncInternal(options: GetTokenOptions): Promise<GoogleTokenData> {
if (!this.isTokenExpiring() && options.forceRefresh) {
return Promise.resolve(this.#tokenData!.token);
}
return this.#requestToken();
}
async #requestToken(): Promise<GoogleTokenData> {
if (!this.options.email) {
throw new Error("No email provided");
}
if (!this.options.key) {
throw new Error("No private key provided");
}
const issuedTokenAt = DateTime.now().toSeconds();
const additionalClaims = this.options.additionalClaims ?? {};
const jwtPayload = {
iss: this.options.email,
scope: this.#scope.get().join(" "),
aud: GOOGLE_TOKEN_URL,
exp: issuedTokenAt + 3600,
iat: issuedTokenAt,
additionalClaims,
sub: this.options.sub,
};
const key = await importPKCS8(this.options.key, "RS256");
const signedJwt = await new SignJWT(jwtPayload)
.setProtectedHeader({ alg: "RS256" })
.sign(key);
try {
const res = await fetch(GOOGLE_TOKEN_URL, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
body: new URLSearchParams({
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: signedJwt,
}),
});
this.#tokenData = await res.json<GoogleTokenData>().then((data) => ({
token: data,
expiresAt: DateTime.fromSeconds(issuedTokenAt).plus({
seconds: data.expires_in,
}),
}));
return this.#tokenData!.token;
} catch (e) {
console.error(e);
throw e;
}
}
}
export interface TokenOptions {
keyFile?: string;
key?: string;
email?: string;
iss?: string;
sub?: string;
scope?: string | string[];
additionalClaims?: {};
/** Eagerly refresh the token if it is within this many milliseconds from expiring. Defaults to 0. */
eagerRefreshThresholdMillis?: number;
}
interface TokenData {
token: GoogleTokenData;
expiresAt: DateTime;
}
export interface AdminSdkCredentials {
type: string;
projectID: string;
privateKeyID: string;
projectId: string;
privateKeyId: string;
privateKey: string;
clientEmail: string;
clientID: string;
authURI: string;
tokenURI: string;
authProviderX509CERTURL: string;
clientX509CERTURL: string;
authProviderX509CertUrl: string;
clientX509CertUrl: string;
universeDomain: string;
}

View File

@@ -9,7 +9,7 @@ export async function sendFcmMessage(
isOnlyValidatingFcmMessage?: boolean,
) {
return fetch(
`https://fcm.googleapis.com/v1/projects/${adminSdkJson.projectID}/messages:send`,
`https://fcm.googleapis.com/v1/projects/${adminSdkJson.projectId}/messages:send`,
{
method: "POST",
body: JSON.stringify({

45
src/libs/lazy.spec.ts Normal file
View File

@@ -0,0 +1,45 @@
import { describe, expect, it } from "bun:test";
import { lazy } from "./lazy";
describe("lazy", () => {
it("lazy value returned when get is called", () => {
const value = lazy(() => "value");
expect(value.get()).toBe("value");
});
it("lazy function not called if get isn't called", () => {
let setValue = false;
lazy(() => {
setValue = true;
return "value";
});
expect(setValue).toBeFalse();
});
it("lazy function called if get is called", () => {
let setValue = false;
lazy(() => {
setValue = true;
return "value";
}).get();
expect(setValue).toBeTrue();
});
it("lazy function called only once if get is called multiple times", () => {
let count = 0;
const value = lazy(() => {
count++;
return "value";
});
const NUM_TIMES_CALLED = 1_000_000;
for (let i = 0; i < NUM_TIMES_CALLED; i++) {
value.get();
}
expect(count).toBe(1);
});
});

24
src/libs/lazy.ts Normal file
View File

@@ -0,0 +1,24 @@
export interface Lazy<T> {
get: () => T;
}
class LazyImpl<T> {
#value?: T;
constructor(private fn: () => T) {}
get(): T {
let value = this.#value;
if (value) {
return value;
}
value = this.fn();
this.#value = value;
return value;
}
}
export function lazy<T>(fn: () => T): Lazy<T> {
return new LazyImpl(fn);
}

View File

@@ -2,6 +2,8 @@ import type { TokenOptions } from "gtoken";
import { mock } from "bun:test";
import type { AdminSdkCredentials } from "~/libs/fcm/getGoogleAuthToken";
const emailRegex =
/^(([^<>()[\]\\.,;:\s@"]+(\.[^<>()[\]\\.,;:\s@"]+)*)|.(".+"))@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\])|(([a-zA-Z\-0-9]+\.)+[a-zA-Z]{2,}))$/;
@@ -27,4 +29,12 @@ class MockGoogleToken {
}
}
mock.module("gtoken", () => ({ GoogleToken: MockGoogleToken }));
mock.module("src/libs/fcm/getGoogleAuthToken", () => {
return {
getGoogleAuthToken: (adminSdkJson: AdminSdkCredentials) => {
return new MockGoogleToken({
email: adminSdkJson.clientEmail,
}).getToken();
},
};
});