diff --git a/salty.c b/salty.c index 5e20cf7..1341199 100644 --- a/salty.c +++ b/salty.c @@ -33,7 +33,7 @@ bool keyGen(const unsigned char* salt, const char* password, unsigned char* key) return true; } -int encryptMessage(const char* inputFile, const char* outputFile, const char* password) { +int encryptFile(const char* inputFile, const char* outputFile, const char* password) { // Generate salt unsigned char salt[SALT_SIZE]; @@ -177,7 +177,7 @@ int encryptMessage(const char* inputFile, const char* outputFile, const char* pa return 0; } -int decryptMessage(const char* inputFile, const char* outputFile, const char* password) { +int decryptFile(const char* inputFile, const char* outputFile, const char* password) { // Open output file FILE *outFile = fopen(outputFile,"wb"); if (outFile == NULL) { @@ -232,6 +232,12 @@ int decryptMessage(const char* inputFile, const char* outputFile, const char* pa // Get the size of the file and read it inLen = fileSize(inFile); unsigned char* input = malloc(inLen); + if (input == NULL) { + fprintf(stderr, ERR"Memory allocation error: File too large.\n"); + fclose(inFile); + fclose(outFile); + return 1; + } if (fread(input,1,inLen,inFile) != inLen) { fprintf(stderr,ERR"Error reading from input file.\n"); @@ -255,7 +261,13 @@ int decryptMessage(const char* inputFile, const char* outputFile, const char* pa fprintf(stderr,OK"Size of input file: %zu bytes.\n",inLen); fprintf(stderr,OK"Encrypted content is %zu bytes.\n",encLen); - unsigned char encrypted[encLen]; + unsigned char* encrypted = malloc(encLen); + if (encrypted == NULL) { + fprintf(stderr,ERR"Memory allocation error."); + free(fullInput); + fclose(outFile); + return 1; + } // Verify file size_t offset = 0; @@ -276,15 +288,17 @@ int decryptMessage(const char* inputFile, const char* outputFile, const char* pa memcpy(nonce, fullInput + offset, sizeof(nonce)); offset += sizeof(nonce); - memcpy(encrypted, fullInput + offset, sizeof(encrypted)); + memcpy(encrypted, fullInput + offset, encLen); fprintf(stderr,OK"Data retrieved.\n"); + free(fullInput); + // Key unsigned char key[KEY_SIZE]; if (!keyGen(salt,password,key)) { - free(fullInput); + free(encrypted); fclose(outFile); return 1; } @@ -292,24 +306,25 @@ int decryptMessage(const char* inputFile, const char* outputFile, const char* pa fprintf(stderr,OK"Proceeding to decrypt file...\n"); size_t decLen = encLen - crypto_secretbox_MACBYTES; - unsigned char decrypted[decLen]; + unsigned char* decrypted = malloc(decLen); if (crypto_secretbox_open_easy(decrypted,encrypted,encLen,nonce,key) < 0) { fprintf(stderr,ERR"Error decrypting file.\n"); - free(fullInput); + free(encrypted); fclose(outFile); return 1; } if (fwrite(decrypted,1,decLen,outFile) != decLen) { fprintf(stderr,ERR"Error writing data to file.\n"); - free(fullInput); + free(encrypted); fclose(outFile); return 1; } fclose(outFile); - free(fullInput); + free(encrypted); + free(decrypted); fprintf(stderr,"\n"OK"File decrypted!\n"); @@ -469,8 +484,8 @@ int main(int argc, char *argv[]) { } if (decrypt) { - return decryptMessage(input,output,password); + return decryptFile(input, output, password); } else { - return encryptMessage(input,output,password); + return encryptFile(input, output, password); } }