I was reading up on DES while waiting for a gigantic Perforce sync over VPN, and something struck me as interesting: DES is a Feistel network, and so are a number of other modern ciphers. This suggests abstraction: can we write a generic Feistel network, and then implement a variety of ciphers in terms of that? This post is basically the result of posing that question. I wanted to cover both DES and AES in a single post, but just presenting DES took a fair bit of text, and then I realized that AES is not Feistel. The post does, however, present a complete implementation of DES and 3TDES, and I’d like to follow it up with other ciphers later on.

As a word of caution, literally all I know about cryptography, DES, Feistel ciphers, or anything else in that field comes from about two days’ worth of reading up on it while waiting for builds and syncs. Feistel networks rely on a cryptographically-strong pseudorandom function. An incorrect implementation of that function may let you encrypt and decrypt plaintext, but do so in a way that’s potentially cryptographically useless. If you use any of the code here (or in any other of my posts, for that matter) for anything remotely sensitive, please make sure that you verify it against a description of the algorithm that is known to be correct.

A Feistel network is a block cipher: that is, it’s a cipher that acts on fixed-length blocks. It’s described by the number of *rounds* (iterations), a set of *subkeys*, one for each round (also called a *key schedule*), and two functions: *(+)* and *f*. The former is commonly bitwise addition modulo 2 (exclusive OR, in other words), while *f* is the so-called *round function*, which we’ll talk about shortly. Thus specified, the resulting Feistel network is a process that takes a block, splits it into two equal halves, call them *L_0* and *R_0*, and proceeds to apply the following iterations to it, one per round:

*L_{i+1} = R_i*

*R_{i+1} = L_i (+) f(R_i, K_i)*

If *n* is the number of rounds, then *(L_n, R_n)* is the final ciphertext, which we’ll usually merge back into a single number. The process can be interpreted as follows: at each iteration, the right half *R* crosses into the left half *L*, and the left half *L* is scrambled using *(+)* and *f*, and crossed into *R*. One half is always “more” encrypted than the other, by one round. The idea is that the round function *f* is something weird and non-invertible; to be precise, *f* is a cryptographically secure pseudorandom function with *K_i* as the seed.

Michael Luby and Charles Rackoff showed that if this is the case, then four rounds are sufficient to make the corresponding Feistel network a *strong* pseudorandom permutation, meaning that it remains pseudorandom even if the inverse permutation is discovered. This is a good property to have, since otherwise you can’t publish the algorithm; besides, cryptographers appear to be fond of providing hypothetical adversaries with access to an omniscient oracle who knows the details of the algorithm. After four rounds, the oracle doesn’t help. Plain DES uses 16 rounds. Triple-DES uses 48.

The ability to decrypt a ciphertext hinges on the definition of *(+)*, in that we need to be able to reverse the process as follows:

*R_i = L_{i+1}*

*L_i = R_{i+1} (+) f(L_{i+1}, K_i)*

The reason this works is that *(+)* is picked so that if one of the parameters is held constant, the function is its own inverse — so we’re backtracking from the last subkey back to the first. Exclusive OR works here, but we could pick other functions as well (although, it seems that anything more complex than XOR could as easily be absorbed into f instead — do Feistel ciphers exist that have a mixing function other than XOR?).

Alternatively, we can reverse the process by reversing the key schedule *K*, swapping *L* and *R*, and using the same network as we did for encryption. DES swaps the *L* and *R* halves before joining them into the final ciphertext, so it can be inverted by simply reversing the key schedule. Any generic description of a Feistel cipher is going to be a higher-order function, since it takes the functions *(+)* and *f* as arguments. Let’s write it:

    feistel :: (Bits a) => (a -> b -> a)  -- Mixing function
                        -> (a -> k -> b)  -- Round function
                        -> a              -- Block to be encrypted
                        -> [k]            -- Key schedule
                        -> (a, a)         -- (L, R)
    feistel (+) f block keys = foldl rnd (l, r) keys
      where
        (l, r) = split block half
        half = (bitSize block) `div` 2
        rnd (l, r) k = (r, l + f r k)

    split n k = (n .&. (2^k-1), n `shiftR` k)

    merge l r k = (r `shiftL` k) .|. l

As a side note, in C++, I would write the above as an abstract base class with pure virtuals for *(+)* and *f*. That sort of translation seems to crop up fairly often.

The above is hopefully a fairly straightforward implementation of the verbal description. We want to parametrize our Feistel implementation over all fixed-width types with bitwise operations on them, so we require that the type of a data block is in *Bits*. We then ask for the mixing function *(+)*, the round function *f*, the block itself, and the key schedule. We split the block in half, and apply the iterations as described earlier (note that we infer the number of rounds from the key schedule). Since each round takes the results of the previous round, a straightforward way to implement the process is to describe a single round, and then left-fold the key schedule over it.

Finally, to make the bit fiddling easier for subsequent applications, we generalize splitting a block into halves and merging it back, by writing the functions split and merge.
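To see the decryption argument in action, here is a minimal sketch of a toy Feistel network on pairs of 16-bit halves. The names *feistelPair*, *toyF*, *encryptPair* and *decryptPair* are made up for this illustration, and the round function is an arbitrary non-invertible mess rather than anything secure; the point is only that reversing the key schedule and swapping the halves undoes the encryption, whatever *f* happens to be.

```haskell
import Data.Bits (xor)
import Data.Tuple (swap)
import Data.Word (Word16)

-- Run the rounds over an (L, R) pair; same shape as rnd in feistel above,
-- with XOR fixed as the mixing function.
feistelPair :: (Word16 -> Word16 -> Word16) -> [Word16] -> (Word16, Word16) -> (Word16, Word16)
feistelPair f keys lr = foldl rnd lr keys
  where
    rnd (l, r) k = (r, l `xor` f r k)

-- Hypothetical round function: deliberately non-invertible, NOT secure.
toyF :: Word16 -> Word16 -> Word16
toyF r k = (r * r + k) `xor` 0x5a5a

-- Swap the halves at the end, as DES does.
encryptPair :: [Word16] -> (Word16, Word16) -> (Word16, Word16)
encryptPair keys = swap . feistelPair toyF keys

-- Same network, reversed key schedule.
decryptPair :: [Word16] -> (Word16, Word16) -> (Word16, Word16)
decryptPair keys = swap . feistelPair toyF (reverse keys)
```

With any key list *ks*, *decryptPair ks (encryptPair ks block)* should give back *block*.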

A Feistel cipher is, more generally, a type of product cipher. Product ciphers are block ciphers that execute, in sequence, a series of relatively simple transformations of the plaintext block. Commonly, these transformations include bitwise permutations (P-boxes), substitutions (S-boxes), and linear mixing (our *(+)* function). In the case of DES, there is a handful of P-boxes, 8 S-boxes, and XOR for mixing.

A permutation box is simply a bitwise permutation: we shuffle the bits around according to some table. Most of the permutation boxes in DES are invertible, but some are not. Let’s write the code for applying a permutation box in general.

    permute :: (Num a, Bits a) => [Int] -> a -> a
    permute table key = foldl shuffle 0 (zip table [0..])
      where
        shuffle k (n+1, b) = k .|. (isSet n `shiftL` b)
        isSet n = if testBit key n then 1 else 0

Note the *(n+k)* pattern in *shuffle*: the DES documentation I’ve read uses the convention that LSB is bit 1, not bit 0, and I didn’t want to have to convert each table. (Such patterns were removed in Haskell 2010, so recent versions of GHC need the *NPlusKPatterns* extension to accept them.) This is fairly straightforward as well, but if you don’t read Haskell, the idea is as follows. We take a table of source bit positions, presented as a list: so [4,2,1] means that bit 4 is shuffled into position 1, bit 2 remains the same, and bit 1 is moved to position 3. We decorate that table with the destination positions, by saying *zip table [0..]*, which evaluates to a list of pairs *(source_bit, destination_bit)*, with sources counted from 1 and destinations from 0. Finally, we run over that list, and set bits as appropriate.
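For readers who want to poke at this without the *(n+k)* pattern, here is a small sketch of the same idea specialised to 8-bit values. The name *permuteBits* is made up for this example; it follows the same convention as *permute*, with 1-based source positions in the table and destinations given by the index in the list.

```haskell
import Data.Bits (setBit, testBit)
import Data.List (foldl')
import Data.Word (Word8)

-- Entry number i of the table (counting from 0) names the 1-based source
-- bit that ends up in bit i of the result.
permuteBits :: [Int] -> Word8 -> Word8
permuteBits table key = foldl' shuffle 0 (zip table [0 ..])
  where
    shuffle acc (src, dst)
      | testBit key (src - 1) = setBit acc dst
      | otherwise             = acc
```

With the table [4,2,1], an input with only bit 4 set (the value 8) comes out with only bit 1 set (the value 1), and an input with only bit 1 set comes out with only bit 3 set (the value 4).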

The above is almost everything we need to implement DES, and we haven’t even discussed the algorithm. I’ll use the implementation of single-block encryption as a way of introducing the process.

    des keys block = applyFp (merge' $ feistel xor f (applyIp block) keys)
      where
        f r key =
          let bs = take 8 $ unfoldr (Just . shift) nr
              nr = xor (ebitSelect r) key
              shift k = (k .&. 0x3f, k `shiftR` 6)
          in pPerm (applySboxes bs)
        merge' (l,r) = merge r l 32

Let’s go through this step by step. To get some of the undefined functions out of the way, *applyIp* and *applyFp* are P-boxes, where Ip stands for *Initial Permutation*, and Fp stands for *Final Permutation* (also called *Inverse Initial Permutation*). Similarly, *ebitSelect* and *pPerm* are P-boxes which are applied at various stages of computation of *f*. Finally, *applySboxes* applies, as the name suggests, the S-boxes. The details of box implementations mostly consist of table data, so we’ll concentrate on the algorithm first.

First, *feistel* is evaluated, with xor as the mixing function, *f* as the round function, *block* with *“Initial Permutation”* applied as the data block, and the input key schedule (note that we haven’t yet discussed how the key schedule is computed either). The crux of the algorithm is then the round function *f*. As specified by *feistel*, the function takes *R_i* and *K_i*, and does something dodgy to them. The process is this:

a) Permute *R_i* using the *ebitSelect* P-box.

b) XOR the result with *K_i*.

c) Split the resulting 48-bit value (see below) into eight 6-bit values *B_1* … *B_8*.

d) Run *B_1* … *B_8* through the corresponding S-boxes.

e) Run the result through the *pPerm* P-box.
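Step (c) is the only piece of bit fiddling in that list that isn’t a table lookup, so here it is in isolation. This is a sketch under the same convention as the *des* function (lowest six bits first); the name *sixBitChunks* is made up for this example.

```haskell
import Data.Bits (shiftR, (.&.))
import Data.List (unfoldr)
import Data.Word (Word64)

-- Peel off eight 6-bit values, starting from the least significant bits.
sixBitChunks :: Word64 -> [Word64]
sixBitChunks = take 8 . unfoldr (\k -> Just (k .&. 0x3f, k `shiftR` 6))
```

For instance, *sixBitChunks 0x81* gives [1,2,0,0,0,0,0,0]: the low six bits of 0x81 are 000001, and the next six are 000010.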

As I mentioned in the introduction, my knowledge of cryptography is limited to what I’ve read in the past couple of days, so I can’t, unfortunately, detail the requirements on the specific P-boxes or the S-boxes — I understand that the combination of them needs to make *f* a cryptographic PRNG in *K_i*, but I don’t know enough about PRNGs to comment on why those particular transformations, in that particular sequence, are the right thing to do. My guess is that the S-box values are picked to avoid short cycles, and the two permutations lengthen the cycles further.

Once the Feistel network is applied, we apply another P-box, *applyFp*, the *“Final Permutation.”* This is the inverse of the *“Initial Permutation.”* In addition to this, we use a new function, *merge’*, to merge the results into a single ciphertext block while swapping *L* and *R*, which is a simple but subtle detail. Recall how, while discussing the deciphering stage of a Feistel network, we noted that deciphering can be done by simply reversing the key schedule, and swapping *L* and *R*. Since the swap is performed at this stage, the decryption function is simply the encryption function with the key schedule reversed. This is nice, because *des keys* encrypts a block, while *des (reverse keys)* decrypts it.

We can write a group of functions to make the above actually usable:

    -- Encrypts a text message
    encryptDES key = map (encryptBlock key) . preparePlaintext

    -- Decrypts a DES-encoded message
    decryptDES key = readPlaintext . map (decryptBlock key)

    -- Encrypts a block
    encryptBlock key = des (keySchedule (pc1 key))

    -- Decrypts a block
    decryptBlock key = des (reverse $ keySchedule (pc1 key))

Here, *pc1* is yet another P-box, called *“Permuted Choice 1.”* Several new functions have made an appearance: functions to prepare the plaintext (convert it to 64-bit words), read plaintext from a list of 64-bit words, and compute a key schedule from a key.

We’ll jump ahead a little bit and divulge a major detail of the *pc1* P-box. The original DES algorithm operates on 64-bit keys, but one bit of every byte is used as a parity bit (bits 8, 16, …, 64 never appear in the *pc1* table), thereby reducing the actual key size to 56. The key schedule splits that 56-bit value into two 28-bit halves, rotates them 1 or 2 bits to the left depending on a table, merges the result into a key, and feeds that into the next iteration. Each round’s subkey is then cut down to 48 bits by one more P-box, which is why the S-box step was dividing its input into eight 6-bit blocks: 8*6=48. There are 16 iterations total:

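    keySchedule key = schedule bits l r
      where
        bits = [1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1]
        (l,r) = split key 28
        schedule [] _ _ = []
        schedule (n:ns) l r =
          -- Caution: rotateL rotates across the full width of the type, so
          -- with Word64 halves the bits pushed past position 28 do not wrap
          -- around within the 28-bit half.
          let l' = rotateL l n
              r' = rotateL r n
              k' = merge l' r' 28
          in pc2 k':schedule ns l' r'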

Here, we see our final P-box: *pc2*, which stands for *“Permuted Choice 2.”* We’ve reused split and merge from earlier. All that’s left is the plaintext conversion functions:

    preparePlaintext :: [Char] -> [Word64]
    preparePlaintext text = text64 (to64 text)
      where
        to64 = map (fromIntegral . ord)
        text64 [] = []
        text64 ns =
          let (x,rest) = splitAt 8 ns
              w64 = foldl1 shiftOr (reverse x)
              shiftOr k r = (k `shiftL` 8) .|. r
          in w64:text64 rest

    readPlaintext :: [Word64] -> [Char]
    readPlaintext text = concatMap to8 text
      where
        to8 = map (chr . fromIntegral) . take 8 . from64
        from64 = unfoldr (\k -> Just (k .&. 0xff, k `shiftR` 8))

The first function takes a list of characters and converts them into a list of 64-bit words. The second function reverses the process. Ideally, this would be bit-width aware (in order to support different character encodings), and probably work on ByteStrings instead of [Char], but those are trivial changes.

With the exception of the tables for P-boxes, the only bit left is the S-box implementation. The S-box process is slightly weird in DES. Recall that we called *applySboxes* on a list of 6-bit words, *B_1* through *B_8*. There are eight S-boxes, call them *S_1* through *S_8*. Each *S_k* is a 4×16 array of 4-bit values. The “substitution” part of the S-box comes from the fact that we substitute each *B_k* for *S_k[i][j]*, where i is the 2-bit value composed of the first and last bit of *B_k*, and j is the 4-bit value in the middle. These eight 4-bit substitutions are then combined into a 32-bit value. The above can be stated as follows:

    applySboxes :: [Word64] -> Word64
    applySboxes bs = foldr1 (\r k -> (k `shiftL` 4) .|. r) (apply bs sboxes :: [Word64])
      where
        apply [] [] = []
        apply (b:bs) (s:ss) =
          let i = ((b .&. 0x20) `shiftR` 4) .|. (b .&. 1)
              j = (b `shiftR` 1) .&. 0xf
          in ((s ! i) ! j):(apply bs ss)
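The row and column arithmetic is easy to get wrong, so here is that piece on its own, as a sketch with a made-up name (*sboxIndex*). It uses the same expressions for i and j as *applySboxes*, just returned as a pair instead of being used for a lookup.

```haskell
import Data.Bits (shiftR, (.&.), (.|.))
import Data.Word (Word64)

-- Row: first and last bit of the 6-bit input. Column: the four middle bits.
sboxIndex :: Word64 -> (Word64, Word64)
sboxIndex b = (row, col)
  where
    row = ((b .&. 0x20) `shiftR` 4) .|. (b .&. 1)
    col = (b `shiftR` 1) .&. 0xf
```

For the input 101101 in binary (0x2d), the outer bits are 1 and 1, giving row 3, and the middle bits are 0110, giving column 6.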

And, with the exception of the actual P-box and S-box data, we’re done. The data is given below, but first, a quick test:

    *Crypto> encryptDES 0x0E329232EA6D0D73 "This is an encrypted message"
    [13285721053034039710,15909830152601400232,6882462973091073748,1916557633710311361]
    *Crypto> decryptDES 0x0E329232EA6D0D73 it
    "This is an encrypted message\NUL\NUL\NUL\NUL"

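Cool, no? Here’s the really cool part: Comparatively speaking, DES is not very secure at all, since a 56-bit key is small enough to be found by brute force. Triple DES, which uses three keys instead of one, is considerably stronger. To turn our DES implementation into triple DES, we simply triple the key schedule: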

    -- Encrypts a text message
    encryptDES3 keys = map (encryptBlock3 keys) . preparePlaintext

    -- Decrypts a DES-encoded message
    decryptDES3 keys = readPlaintext . map (decryptBlock3 keys)

    -- Encrypts a block
    encryptBlock3 (k1, k2, k3) = des (keySchedule (pc1 k1) ++ keySchedule (pc1 k2) ++ keySchedule (pc1 k3))

    -- Decrypts a block
    decryptBlock3 (k1, k2, k3) = des (reverse $ keySchedule (pc1 k1) ++ keySchedule (pc1 k2) ++ keySchedule (pc1 k3))

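Everything infers the number of rounds from the key schedule, so by pasting three different key schedules on top of each other, we get a 48-round network with no further work. Strictly speaking, standard Triple DES is specified as three complete DES passes in encrypt-decrypt-encrypt order, so this is a variation on the idea and won’t interoperate with other 3DES implementations. It’s trivial to actually generalize the encrypt…/decrypt… functions over any number of keys, although I don’t know whether the benefits start to erode after triple DES.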

So finally, the moment we’ve all been waiting for, the table data:

    -- Permuted Choice 1
    pc1 :: Word64 -> Word64
    pc1 = permute [57, 49, 41, 33, 25, 17, 9, 1,
                   58, 50, 42, 34, 26, 18, 10, 2,
                   59, 51, 43, 35, 27, 19, 11, 3,
                   60, 52, 44, 36, 63, 55, 47, 39,
                   31, 23, 15, 7, 62, 54, 46, 38,
                   30, 22, 14, 6, 61, 53, 45, 37,
                   29, 21, 13, 5, 28, 20, 12, 4]

    -- Permuted Choice 2
    pc2 :: Word64 -> Word64
    pc2 = permute [14, 17, 11, 24, 1, 5, 3, 28,
                   15, 6, 21, 10, 23, 19, 12, 4,
                   26, 8, 16, 7, 27, 20, 13, 2,
                   41, 52, 31, 37, 47, 55, 30, 40,
                   51, 45, 33, 48, 44, 49, 39, 56,
                   34, 53, 46, 42, 50, 36, 29, 32]

    applyIp :: Word64 -> Word64
    applyIp = permute [58, 50, 42, 34, 26, 18, 10, 2,
                       60, 52, 44, 36, 28, 20, 12, 4,
                       62, 54, 46, 38, 30, 22, 14, 6,
                       64, 56, 48, 40, 32, 24, 16, 8,
                       57, 49, 41, 33, 25, 17, 9, 1,
                       59, 51, 43, 35, 27, 19, 11, 3,
                       61, 53, 45, 37, 29, 21, 13, 5,
                       63, 55, 47, 39, 31, 23, 15, 7]

    applyFp :: Word64 -> Word64
    applyFp = permute [40, 8, 48, 16, 56, 24, 64, 32,
                       39, 7, 47, 15, 55, 23, 63, 31,
                       38, 6, 46, 14, 54, 22, 62, 30,
                       37, 5, 45, 13, 53, 21, 61, 29,
                       36, 4, 44, 12, 52, 20, 60, 28,
                       35, 3, 43, 11, 51, 19, 59, 27,
                       34, 2, 42, 10, 50, 18, 58, 26,
                       33, 1, 41, 9, 49, 17, 57, 25]

    ebitSelect :: Word64 -> Word64
    ebitSelect = permute [32, 1, 2, 3, 4, 5,
                          4, 5, 6, 7, 8, 9,
                          8, 9, 10, 11, 12, 13,
                          12, 13, 14, 15, 16, 17,
                          16, 17, 18, 19, 20, 21,
                          20, 21, 22, 23, 24, 25,
                          24, 25, 26, 27, 28, 29,
                          28, 29, 30, 31, 32, 1]

    pPerm :: Word64 -> Word64
    pPerm = permute [16, 7, 20, 21, 29, 12, 28, 17,
                     1, 15, 23, 26, 5, 18, 31, 10,
                     2, 8, 24, 14, 32, 27, 3, 9,
                     19, 13, 30, 6, 22, 11, 4, 25]

    sboxes :: [Array Word64 (Array Word64 Word64)]
    sboxes = map (listArray (0,3) . map (listArray (0,15)))
      -- S-Box 1:
      [[[14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7],
        [0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8],
        [4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0],
        [15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13]],
       -- S-Box 2:
       [[15, 1, 8, 14, 6, 11, 3, 4, 9, 7, 2, 13, 12, 0, 5, 10],
        [3, 13, 4, 7, 15, 2, 8, 14, 12, 0, 1, 10, 6, 9, 11, 5],
        [0, 14, 7, 11, 10, 4, 13, 1, 5, 8, 12, 6, 9, 3, 2, 15],
        [13, 8, 10, 1, 3, 15, 4, 2, 11, 6, 7, 12, 0, 5, 14, 9]],
       -- S-Box 3:
       [[10, 0, 9, 14, 6, 3, 15, 5, 1, 13, 12, 7, 11, 4, 2, 8],
        [13, 7, 0, 9, 3, 4, 6, 10, 2, 8, 5, 14, 12, 11, 15, 1],
        [13, 6, 4, 9, 8, 15, 3, 0, 11, 1, 2, 12, 5, 10, 14, 7],
        [1, 10, 13, 0, 6, 9, 8, 7, 4, 15, 14, 3, 11, 5, 2, 12]],
       -- S-Box 4:
       [[7, 13, 14, 3, 0, 6, 9, 10, 1, 2, 8, 5, 11, 12, 4, 15],
        [13, 8, 11, 5, 6, 15, 0, 3, 4, 7, 2, 12, 1, 10, 14, 9],
        [10, 6, 9, 0, 12, 11, 7, 13, 15, 1, 3, 14, 5, 2, 8, 4],
        [3, 15, 0, 6, 10, 1, 13, 8, 9, 4, 5, 11, 12, 7, 2, 14]],
       -- S-Box 5:
       [[2, 12, 4, 1, 7, 10, 11, 6, 8, 5, 3, 15, 13, 0, 14, 9],
        [14, 11, 2, 12, 4, 7, 13, 1, 5, 0, 15, 10, 3, 9, 8, 6],
        [4, 2, 1, 11, 10, 13, 7, 8, 15, 9, 12, 5, 6, 3, 0, 14],
        [11, 8, 12, 7, 1, 14, 2, 13, 6, 15, 0, 9, 10, 4, 5, 3]],
       -- S-Box 6:
       [[12, 1, 10, 15, 9, 2, 6, 8, 0, 13, 3, 4, 14, 7, 5, 11],
        [10, 15, 4, 2, 7, 12, 9, 5, 6, 1, 13, 14, 0, 11, 3, 8],
        [9, 14, 15, 5, 2, 8, 12, 3, 7, 0, 4, 10, 1, 13, 11, 6],
        [4, 3, 2, 12, 9, 5, 15, 10, 11, 14, 1, 7, 6, 0, 8, 13]],
       -- S-Box 7:
       [[4, 11, 2, 14, 15, 0, 8, 13, 3, 12, 9, 7, 5, 10, 6, 1],
        [13, 0, 11, 7, 4, 9, 1, 10, 14, 3, 5, 12, 2, 15, 8, 6],
        [1, 4, 11, 13, 12, 3, 7, 14, 10, 15, 6, 8, 0, 5, 9, 2],
        [6, 11, 13, 8, 1, 4, 10, 7, 9, 5, 0, 15, 14, 2, 3, 12]],
       -- S-Box 8:
       [[13, 2, 8, 4, 6, 15, 11, 1, 10, 9, 3, 14, 5, 0, 12, 7],
        [1, 15, 13, 8, 10, 3, 7, 4, 12, 5, 6, 11, 0, 14, 9, 2],
        [7, 11, 4, 1, 9, 12, 14, 2, 0, 6, 10, 13, 15, 3, 5, 8],
        [2, 1, 14, 7, 4, 10, 8, 13, 15, 12, 9, 0, 3, 5, 6, 11]]]


Arithmetic coding is a remarkably simple and clever thing. The idea is that given some half-open interval [a,b), that is, the interval a <= x < b, we can partition it into half-open subintervals, such that there is one subinterval per distinct character in the message to be encoded, and the lengths correspond to the character frequencies multiplied by b-a. The same procedure is applied, recursively, to each subinterval, resulting in an infinite hierarchy of coverings of the original interval, which we’ll call S. Now, if we throw a rock at S, record the point where it hit, and follow the interval hierarchy, we’ll come up with a unique infinite string of characters.

To construct the actual encoding, set S to [0,1), and find out which subinterval S_1 the first character of the message falls into. For the second character, let S_2 be the appropriate subinterval of S_1, for the third character, let S_3 be the appropriate subinterval of S_2, and so on; if we repeat this procedure as many times as there are characters, we’ll arrive at some interval S_n. Numbers that fall in this interval have a useful property: given any such number, call it x, we have x in S_{n-1} (since x is in S_n, and S_n is a subinterval of S_{n-1}), x in S_{n-2} by the same argument, and, by induction, in every subinterval that we picked while encoding the message. Any such x, therefore, uniquely encodes the message: to decode, simply follow the hierarchy.
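Here is the narrowing procedure worked through on a tiny message, as a sketch. The two-letter alphabet and its fixed ranges (‘a’ gets [0,2/3), everything else gets [2/3,1)) are assumptions chosen to match the frequencies in the message "aab"; the names *narrow*, *ranges* and *finalInterval* are made up for this example.

```haskell
import Data.Ratio ((%))

type Interval = (Rational, Rational)

-- Pick the subinterval (x, y) of [0,1) scaled into the current interval (a, b).
narrow :: Interval -> Interval -> Interval
narrow (a, b) (x, y) = (a + (b - a) * x, a + (b - a) * y)

-- Hypothetical fixed ranges: 'a' occupies [0, 2/3), anything else [2/3, 1).
ranges :: Char -> Interval
ranges 'a' = (0, 2 % 3)
ranges _   = (2 % 3, 1)

-- S_n for a whole message, starting from S = [0,1).
finalInterval :: String -> Interval
finalInterval = foldl narrow (0, 1) . map ranges
```

For "aab" the intervals shrink as [0,1), [0,2/3), [0,4/9), and finally [8/27,4/9); any fraction in that last interval identifies the message once its length is known.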

    *Arith> encodeToStream "encodeToStream returns a pair of lists of bytes, representing the numerator and denominator, respectively."
    ([174,77,70,217,88,196,42,26,75,253,160,72,114,92,77,135,32,165,50,80,55,77,233,103,172,90,177,4],[211,29,119,249,50,167,209,90,128,245,114,158,13,236,212,196,11,81,64,169,125,254,83,235,75,2,30,13])
    *Arith> encode "testing testing testing"
    (23,[(' ',(0%1,2%23)),('e',(2%23,5%23)),('g',(5%23,8%23)),('i',(8%23,11%23)),('n',(11%23,14%23)),('s',(14%23,17%23)),('t',(17%23,1%1))],3430733247%4363211066)
    *Arith> (decode . encode . decode . encode) "testing testing testing"
    "testing testing testing"
    *Arith> encode (concat $ replicate 500 "abcd")
    (2000,[('a',(0%1,1%4)),('b',(1%4,1%2)),('c',(1%2,3%4)),('d',(3%4,1%1))],9%85)

The last test shows the output of ‘encode’ : the length of the message is 2000 characters, this is followed by character distributions (in a practical setting, frequencies would be returned instead of explicit intervals), and finally the encoded message. The entire 2000 byte string is encoded in the fraction 9/85.

Toy code follows. As mentioned earlier, ‘encodeToStream’ is a helper function that breaks the fraction into a pair of lists of bytes; the actual encoder and decoder consist of just ‘encode’, ‘decode’ and ‘freqRanges’, weighing in at 23 lines of code including type annotations and line breaks. Gotta love Haskell.

    {-# OPTIONS -fglasgow-exts #-}
    module Arith where

    import Ratio
    import Data.List
    import Data.Maybe
    import Data.Char
    import qualified Data.Map as M

    type RangeMap k a = [(k, (Ratio a, Ratio a))]

    encode :: (Ord k, Integral a) => [k] -> (Int, RangeMap k a, Rational)
    encode msg = (length msg, M.assocs freqMap, best $ foldl pair (0,1) rmap)
      where
        freqMap = freqRanges msg
        rmap = map (\x -> fromJust $ M.lookup x freqMap) msg
        best (a,b) = approxRational ((b+a)/2) ((b-a)/2)
        pair (a,b) (x,y) = ((b-a)*x+a, (b-a)*y+a)

    decode :: (Ord a, Integral a) => (Int, RangeMap k a, Ratio a) -> [k]
    decode (n, freqs, code) = take n $ decode' code
      where
        findChar x = find (\(c, (a,b)) -> (x >= a) && (x < b)) freqs
        decode' code =
          let (Just (c, (x, y))) = findChar code
          in c:decode' ((code-x) / (y-x))

    freqRanges :: (Ord k, Integral a) => [k] -> M.Map k (Ratio a, Ratio a)
    freqRanges str = snd $ M.mapAccum (\acc x -> (acc + x, (acc, acc + x))) 0 freqs
      where
        freqs = M.map (\p -> p % total) occurences
        occurences = foldl (\m c -> M.insertWith (+) c 1 m) M.empty str
        total = sum (M.elems occurences)

    encodeToStream msg =
      let (len, freqs, code) = encode msg
          (num, denom) = (numerator code, denominator code)
          bytes n = unfoldr (\k -> if k == 0 then Nothing else Just (rem k 256, quot k 256)) n
      in (bytes num, bytes denom)


    a = LSym "a"; b = LSym "b"; c = LSym "c"; p = LSym "p"; q = LSym "q"; r = LSym "r"; s = LSym "s"

    test = LNot $ ((a:->b):&(b:->c)):->(a:->c)                     -- contradiction
    test_ModusPonens = ((p:->q):&p)                                 -- q
    test_ModusTollens = ((p:->q):&(LNot q))                         -- not p
    test_HypSyllogism = ((p:->q):&(q:->r))                          -- p -> r
    test_DisSyllogism = ((p:|q):&(LNot p))                          -- q
    test_ConstrDilemma = ((p:->q):&(r:->s):&(p:|r))                 -- q or s
    test_DestrDilemma = ((p:->q):&(r:->s):&((LNot q):|(LNot s)))    -- not p or not r

    tests = map (pprint . resolve . compile)
                [test, test_ModusPonens, test_ModusTollens, test_HypSyllogism,
                 test_DisSyllogism, test_ConstrDilemma, test_DestrDilemma]

The algorithm consists of bringing the expression into conjunctive normal form (I’m simultaneously compiling into a desugared core language), and applying a set of resolution steps. The resolution steps consist of trivially rejecting things like A \/ ~A \/ …, simplifying things like A \/ B \/ A \/ …, and merging expressions of the form (P \/ A \/ …) /\ (~P \/ B \/ …) into A \/ B \/ … . The process clearly terminates; when it does, the resulting expression is what has been inferred from the conjecture. If the result is a contradiction, then the conjecture is false, if the result is empty then nothing can be inferred, and if it’s non-empty, then we’ve proven an inference rule.
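To make the merging step concrete, here is a single resolution step in isolation, as a sketch. It uses a simplified representation (a clause is a plain list of literals) rather than the core language from this post, and the names *Lit*, *complement* and *resolvePair* are made up for this example.

```haskell
import Data.List (nub)

-- A literal is a symbol or its negation; a clause is a disjunction of literals.
data Lit = Pos String | Neg String deriving (Eq, Show)

type Clause = [Lit]

complement :: Lit -> Lit
complement (Pos x) = Neg x
complement (Neg x) = Pos x

-- (P or A or ...) and (~P or B or ...) ==> (A or B or ...),
-- resolving on the first literal of the left clause that has a complement
-- in the right clause.
resolvePair :: Clause -> Clause -> Maybe Clause
resolvePair c1 c2 =
  case [l | l <- c1, complement l `elem` c2] of
    []      -> Nothing
    (l : _) -> Just (nub (filter (/= l) c1 ++ filter (/= complement l) c2))
```

Modus ponens falls out directly: p -> q compiles to the clause [Neg "p", Pos "q"], and resolving it against [Pos "p"] leaves [Pos "q"].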

Obligatory GHCi session:

    *Main> pprint (resolve $ compile (LNot $ ((a :-> b) :& (b :-> c)) :-> (a :-> c)))
    "(c & ~c)"
    *Main> pprint (resolve $ compile ((p :-> q) :& (LNot q)))
    "~p"

The code is below. Haven’t run into any bugs, but I haven’t tested it extensively.

    module Main where

    import Data.List
    import Data.Maybe

    -- Expressions
    data Expr a = LSym a
                | LNot (Expr a)
                | (Expr a) :& (Expr a)
                | (Expr a) :| (Expr a)
                | (Expr a) :-> (Expr a)
                deriving (Show, Eq)

    -- Minimal core language
    data CExpr a = CSym a
                 | CNot (CExpr a)
                 | CAnd [CExpr a]
                 | COr [CExpr a]
                 deriving (Show, Eq, Ord)

    -- Desugars and transforms to the core language
    compile (LSym x) = CSym x
    compile (a:->b) = COr [compile $ LNot a, compile b]
    compile (LNot a) = CNot (compile a)
    compile (a:|b) = COr [compile a, compile b]
    compile (a:&b) = CAnd [compile a, compile b]

    -- Transforms to CNF
    toCNF (CSym a) = (CSym a)
    toCNF (CNot (COr ts)) = toCNF $ CAnd (map CNot ts)
    toCNF (CNot (CAnd ts)) = toCNF $ COr (map CNot ts)
    toCNF (CNot a) = CNot (toCNF a)
    toCNF (COr ts) =
      let isConj x = case x of (CAnd _) -> True; _ -> False
          conj = find isConj ts
          terms = delete (fromJust conj) ts
      in case conj of
           Nothing -> COr (map toCNF ts)
           Just (CAnd e) -> toCNF (CAnd $ map (\t -> COr (t:terms)) e)
    toCNF (CAnd ts) = CAnd (map toCNF ts)

    -- Flattens nested connectives, etc.
    simplify :: Ord a => CExpr a -> CExpr a
    simplify (CAnd [t]) = t
    simplify (COr [t]) = t
    simplify (CAnd terms) = foldl simplifyAnd (CAnd []) (map simplify (sort terms))
      where
        simplifyAnd (CAnd t) (CAnd t') = CAnd (t ++ t')
        simplifyAnd (CAnd t) x = CAnd (t ++ [x])
    simplify (COr terms) = foldl simplifyOr (COr []) (map simplify (sort terms))
      where
        simplifyOr (COr t) (COr t') = COr (t ++ t')
        simplifyOr (COr t) x = COr (t ++ [x])
    simplify (CNot (CNot e)) = simplify e
    simplify (CNot e) = CNot (simplify e)
    simplify e = e

    -- Resolution step. Applies reduction patterns as long as the expression is reducible.
    resolve e = if e' == e then e else resolve e'
      where
        e' = (toCNF . simplify . reduce) e

    -- Resolution patterns
    -- (P or A or ...) and (~P or B or ...) ==> (A or B or ...)
    reduce (CAnd terms) =
      let unify [] = []
          unify (t:t1:xs) =
            case merge t t1 of
              Just e -> unify (e:xs)
              Nothing -> let (first:rest) = unify (t:xs)
                         in first:unify (t1:rest)
          unify (x:xs) = x:unify xs
          terms' = map reduce (nub terms)
      in CAnd (unify (deleteFirstsBy (==) terms' [COr [], CAnd []]))
    -- P or ~P or A ==> [] and P or P ==> P
    reduce (COr terms) = COr (map reduce (nub (nontrivial terms)))
      where
        nontrivial [] = []
        nontrivial (x:xs) = if any (isNot x) xs then [] else x:nontrivial xs
    reduce x = x

    -- Helper functions
    cNot (CNot x) = x
    cNot x = CNot x

    isNot x y = x == cNot y

    -- Merges expressions of the form (P or A or ... ) and (~P or B or ... ) into A or B or ...
    merge (COr t1) (COr t2) =
      let cancel [] _ = []
          cancel (x:xs) ys =
            case find (isNot x) ys of
              Nothing -> cancel xs ys
              Just e -> x:cancel xs (delete e ys)
      in case cancel t1 t2 of
           [] -> Nothing
           xs -> Just (COr $ (t1 \\ xs) ++ (t2 \\ (map cNot xs)))
    merge (COr t) e = merge e (COr t)
    merge e (COr t) = find (isNot e) t >>= \e' -> Just $ COr (delete e' t)
    merge e1 e2 = Nothing

    pprint :: CExpr String -> String
    pprint (CSym x) = x
    pprint (CNot t) = "~" ++ pprint t
    pprint (COr ts) = "(" ++ concat (intersperse " | " (map pprint ts)) ++ ")"
    pprint (CAnd ts) = "(" ++ concat (intersperse " & " (map pprint ts)) ++ ")"


The press release makes for an interesting read as well.


So, let’s say we take the standard definition of the derivative,

f'(x) = lim_{d→0} (f(x+d) - f(x)) / d,

look at it for a bit, and decide that we don’t like the limit symbol, and that, in fact, we’re going to drop it entirely. After rearranging, we would then obtain the weird-looking f(x) + d f'(x) = f(x+d), and, presumably, set out to find what d could possibly look like.

To this end, we might expand f(x+d) about x, which yields

f(x+d) = f(x) + d f'(x) + d^2 f''(x)/2! + d^3 f'''(x)/3! + …

and, after subtracting f(x) + d f'(x) from both sides, degenerates into

0 = d^2 f''(x)/2! + d^3 f'''(x)/3! + …

It appears that d^n should be zero for all n>1. To see this, we can plug in f(x) = e^x. The coefficients of d^n become constants, and, dividing both sides by e^x, we obtain d^2/2! + d^3/3! + … = 0, which is the Maclaurin series for e^d with the first two terms missing, so e^d=d+1. Ignoring, momentarily, the troublesome fact that this gives d=0 in the reals, we certainly at least have d^2=0 (pardon the handwaving).

But in order for f(x) + d f'(x) = f(x+d) to be a remotely interesting statement, we must also have d != 0, and we want d to be unique. Given some structure whose objects we’d care to differentiate, we’re going to cheat a little, and extend it with the element d such that d != 0, but d^2 = 0. Having allowed such a number, all the weirdness disappears, the equality f(x+d) = f(x) + d f'(x) holds, and we’re left with a sort of infinitesimal constant which we’re free to plug into random things.

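To test whether this makes any sense, we might start by taking the quadratic Q(x) = c_2 x^2 + c_1 x + c_0, and computing Q(x+d):

Q(x+d) = c_2 (x+d)^2 + c_1 (x+d) + c_0 = (c_2 x^2 + c_1 x + c_0) + d (2 c_2 x + c_1) = Q(x) + d Q'(x)

The derivative of Q(x) “fell out” into the coefficient of d. How about e^{x+d}?

e^{x+d} = e^x e^d = e^x (1 + d) = e^x + d e^x

Anything else? Let’s see. By the binomial theorem,

(x+d)^n = sum_{k=0}^{n} C(n,k) x^k d^{n-k}.

The d^{n-k} factor vanishes whenever n >= k+2, so

(x+d)^n = x^n + n x^{n-1} d

and we’ve just obtained the power law.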
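As a quick sanity check of the power law, here is the n = 3 case worked out by hand, using nothing but d^2 = 0 (and hence d^3 = 0):

```latex
\begin{aligned}
(x+d)^3 &= x^3 + 3x^2 d + 3x\,d^2 + d^3 \\
        &= x^3 + 3x^2 d + 3x \cdot 0 + 0 \\
        &= x^3 + d \cdot 3x^2
\end{aligned}
```

The coefficient of d is 3x^2, which is the derivative of x^3, as expected.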

Trigonometric functions:

Using the power law experiment from earlier, we get

Plotting coefficients of d along the vertical axis, and the reals along the horizontal, we get the unit circle, while exponential maps are lines through the origin, which is kind of cool in and of itself.

Ratios:

One useful thing here is that d is small enough to be in *any* radius of convergence, so we get logarithms “for free”:

So, what’s the point? The point is that we automatically obtain the derivative as a side effect of any computation, since f(x+d) = f(x) + d f'(x). In other words, by switching to dual numbers (numbers of the form x+y d) for things that need to be differentiated (and making it transparent through operator overloading), we can ask for the derivative of any differentiable function we’ve ever defined, and it’ll be evaluated exactly (up to floating-point rounding) alongside the function value, with no finite-difference approximation.

To that end, here’s the first approximation in Haskell.

    module Diff where

    -- x :+ x' stands for x + x' d: the value, and the coefficient of d.
    -- Seed a variable as (x :+ 1) and a constant as (x :+ 0).
    data Diffable a = !a :+ a

    funcPart :: Diffable a -> a
    funcPart (x :+ x') = x

    diffPart :: Diffable a -> a
    diffPart (x :+ x') = x'

    instance (RealFloat a) => Eq (Diffable a) where
      (x :+ x') == (y :+ y') = (x == y) && (x' == y')

    instance (Show a) => Show (Diffable a) where
      show (x :+ x') = show (x, x')

    instance (RealFloat a) => Num (Diffable a) where
      (x :+ x') + (y :+ y') = (x + y) :+ (x' + y')
      (x :+ x') * (y :+ y') = (x * y) :+ (x' * y + y' * x)
      negate (x :+ x') = (negate x) :+ (negate x')
      abs (x :+ x') = if x < 0 then ((-x) :+ (-x')) else (x :+ x')
      signum (x :+ x') = (signum x) :+ 0
      -- Literals are constants, so their derivative is zero
      fromInteger x = (fromInteger x) :+ 0

    instance (RealFloat a) => Fractional (Diffable a) where
      fromRational x = (fromRational x) :+ 0
      recip (x :+ x') = (recip x) :+ (negate x' / (x^2))

    instance (RealFloat a) => Floating (Diffable a) where
      pi = pi :+ 0
      exp (x :+ x') = (exp x) :+ (x' * exp x)
      log (x :+ x') = (log x) :+ (x' * recip x)
      sin (x :+ x') = (sin x) :+ (x' * cos x)
      cos (x :+ x') = (cos x) :+ (x' * (negate $ sin x))
      sinh (x :+ x') = (sinh x) :+ (x' * cosh x)
      cosh (x :+ x') = (cosh x) :+ (x' * sinh x)
      asin (x :+ x') = (asin x) :+ (x' / (sqrt $ 1-x^2))
      acos (x :+ x') = (acos x) :+ (x' / (negate (sqrt $ 1-x^2)))
      atan (x :+ x') = (atan x) :+ (x' / (x^2+1))
      asinh (x :+ x') = (asinh x) :+ (x' / (sqrt $ x^2+1))
      acosh (x :+ x') = (acosh x) :+ (x' / (sqrt $ x^2-1))
      atanh (x :+ x') = (atanh x) :+ (x' / (1 - x^2))

And now let’s try a sample GHCi session. In both examples, *(:+ 1)* turns each input into the variable we differentiate with respect to.

Derivatives of x^2 for x in {0, 0.5, …, 5}:

    *Diff> map (diffPart . (^2) . (:+ 1)) [0,0.5..5]
    [0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0]

Derivatives of cos^2 x + sin^2 x (should be identically zero):

    *Diff> map (diffPart . (\x -> (cos x)^2 + (sin x)^2) . (:+ 1)) [0..10]
    [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]

So there we go, AD in half a page of code.
