{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, Trustworthy #-}
module Text.EditDistance.SquareSTUArray (
levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
) where
import Text.EditDistance.EditCosts
import Text.EditDistance.MonadUtilities
import Text.EditDistance.ArrayUtilities
import Control.Monad hiding (foldM)
import Control.Monad.ST
import Data.Array.ST
-- | Levenshtein edit distance between two strings under the given 'EditCosts'.
--
-- Convenience wrapper: measures both strings with 'length' and hands the
-- work to 'levenshteinDistanceWithLengths'. The cost record is forced on
-- entry; the strings themselves are passed through untouched.
levenshteinDistance :: EditCosts -> String -> String -> Int
levenshteinDistance !costs str1 str2 =
    levenshteinDistanceWithLengths costs str1_len str2_len str1 str2
  where
    str1_len :: Int
    str1_len = length str1

    str2_len :: Int
    str2_len = length str2
-- | Levenshtein edit distance for callers that already know both lengths.
--
-- Argument order: costs, length of the first string, length of the second
-- string, first string, second string. The costs and both lengths are forced
-- on entry, then the ST-based worker is run to completion with 'runST'.
--
-- NOTE(review): the lengths are forwarded unchecked and the worker indexes
-- its arrays with unsafe accessors, so they presumably must equal the real
-- string lengths -- confirm against 'stringToArray'.
levenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 =
    runST (levenshteinDistanceST costs str1_len str2_len str1 str2)
-- | ST worker for the Levenshtein distance.
--
-- Fills a full @(str1_len + 1) x (str2_len + 1)@ unboxed cost table indexed
-- by @(i, j)@, where @i@ walks the characters of @str1@ (columns) and @j@
-- walks the characters of @str2@ (rows). The answer is the value in cell
-- @(str1_len, str2_len)@.
levenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
levenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Copy both strings into arrays so characters can be fetched by index.
    -- NOTE(review): they are read at indices 1..len below, so 'stringToArray'
    -- presumably builds 1-based arrays -- confirm in ArrayUtilities.
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Cost table; contents are unspecified until written.
    cost_array <- newArray_ ((0, 0), (str1_len, str2_len)) :: ST s (STUArray s (Int, Int) Int)

    -- Specialised accessors for the three arrays.
    read_str1  <- unsafeReadArray' str1_array
    read_str2  <- unsafeReadArray' str2_array
    read_cost  <- unsafeReadArray' cost_array
    write_cost <- unsafeWriteArray' cost_array

    -- Origin cell: transforming the empty prefix into the empty prefix is
    -- free. Written explicitly so the result does not depend on whatever
    -- 'newArray_' happens to leave in the table.
    write_cost (0, 0) 0

    -- First row (j = 0): running total of deletion costs along str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        write_cost (i, 0) deletion_cost'
        return (i + 1, deletion_cost')

    -- Remaining rows (j >= 1); the fold carries the running insertion total
    -- that goes into the first cell of each row.
    _ <- (\f -> foldM f 0 [1..str2_len]) $ \insertion_cost (!j) -> do
        row_char <- read_str2 j

        -- First cell of the row (i = 0).
        let insertion_cost' = insertion_cost + insertionCost costs row_char
        write_cost (0, j) insertion_cost'

        -- Rest of the row (i >= 1): cheapest of delete / insert / substitute.
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i
            cost <- standardCosts costs read_cost row_char col_char (i, j)
            write_cost (i, j) cost

        return insertion_cost'

    -- Bottom-right cell holds the distance between the full strings.
    read_cost (str1_len, str2_len)
-- | Restricted Damerau-Levenshtein edit distance between two strings under
-- the given 'EditCosts'.
--
-- Convenience wrapper: measures both strings with 'length' and hands the
-- work to 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: EditCosts -> String -> String -> Int
restrictedDamerauLevenshteinDistance costs str1 str2 =
    restrictedDamerauLevenshteinDistanceWithLengths costs str1_len str2_len str1 str2
  where
    str1_len :: Int
    str1_len = length str1

    str2_len :: Int
    str2_len = length str2
-- | Restricted Damerau-Levenshtein edit distance for callers that already
-- know both lengths.
--
-- Argument order: costs, length of the first string, length of the second
-- string, first string, second string. Runs the ST-based worker to
-- completion with 'runST'.
--
-- NOTE(review): the lengths are forwarded unchecked and the worker indexes
-- its arrays with unsafe accessors, so they presumably must equal the real
-- string lengths -- confirm against 'stringToArray'.
restrictedDamerauLevenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths costs str1_len str2_len str1 str2 =
    runST (restrictedDamerauLevenshteinDistanceST costs str1_len str2_len str1 str2)
-- | ST worker for the restricted Damerau-Levenshtein distance.
--
-- Same table layout as the plain Levenshtein worker: an unboxed
-- @(str1_len + 1) x (str2_len + 1)@ cost table indexed by @(i, j)@, where
-- @i@ walks @str1@ (columns) and @j@ walks @str2@ (rows). In addition to
-- delete / insert / substitute, cells with @i >= 2@ and @j >= 2@ may use a
-- swap of two adjacent characters, priced from cell @(i - 2, j - 2)@.
-- The answer is the value in cell @(str1_len, str2_len)@.
restrictedDamerauLevenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
restrictedDamerauLevenshteinDistanceST !costs str1_len str2_len str1 str2 = do
    -- Copy both strings into arrays so characters can be fetched by index.
    -- NOTE(review): they are read at indices 1..len below, so 'stringToArray'
    -- presumably builds 1-based arrays -- confirm in ArrayUtilities.
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Cost table; contents are unspecified until written.
    cost_array <- newArray_ ((0, 0), (str1_len, str2_len)) :: ST s (STUArray s (Int, Int) Int)

    -- Specialised accessors for the three arrays.
    read_str1  <- unsafeReadArray' str1_array
    read_str2  <- unsafeReadArray' str2_array
    read_cost  <- unsafeReadArray' cost_array
    write_cost <- unsafeWriteArray' cost_array

    -- Origin cell: transforming the empty prefix into the empty prefix is
    -- free. Written explicitly so the result does not depend on whatever
    -- 'newArray_' happens to leave in the table.
    write_cost (0, 0) 0

    -- First row (j = 0): running total of deletion costs along str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        write_cost (i, 0) deletion_cost'
        return (i + 1, deletion_cost')

    -- Second row (j = 1): no transposition is possible yet, so only the
    -- standard three operations apply.
    when (str2_len > 0) $ do
        initial_row_char <- read_str2 1

        -- First cell of the row (i = 0).
        write_cost (0, 1) (insertionCost costs initial_row_char)

        -- Rest of the row (i >= 1).
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i
            cost <- standardCosts costs read_cost initial_row_char col_char (i, 1)
            write_cost (i, 1) cost

    -- Remaining rows (j >= 2).
    loopM_ 2 str2_len $ \(!j) -> do
        row_char      <- read_str2 j
        prev_row_char <- read_str2 (j - 1)

        -- First cell of the row (i = 0): cost of inserting the first j
        -- characters of str2. Accumulate on top of the cell above rather than
        -- multiplying this character's cost by j, which is only correct when
        -- every character has the same insertion cost.
        insertions_above <- read_cost (0, j - 1)
        write_cost (0, j) (insertions_above + insertionCost costs row_char)

        -- Second cell of the row (i = 1): still no room for a transposition.
        when (str1_len > 0) $ do
            col_char <- read_str1 1
            cost <- standardCosts costs read_cost row_char col_char (1, j)
            write_cost (1, j) cost

        -- Rest of the row (i >= 2): standard operations, plus a transposition
        -- when the last two characters of each prefix are each other's swap.
        loopM_ 2 str1_len $ \(!i) -> do
            col_char      <- read_str1 i
            prev_col_char <- read_str1 (i - 1)

            standard_cost <- standardCosts costs read_cost row_char col_char (i, j)
            cost <- if prev_row_char == col_char && prev_col_char == row_char
                then do
                    transpose_cost <- fmap (+ transpositionCost costs col_char row_char) $
                                          read_cost (i - 2, j - 2)
                    return (standard_cost `min` transpose_cost)
                else return standard_cost
            write_cost (i, j) cost

    -- Bottom-right cell holds the distance between the full strings.
    read_cost (str1_len, str2_len)
{-# INLINE standardCosts #-}
-- | Cheapest way to reach cell @(i, j)@ using only deletion, insertion or
-- substitution.
--
-- Arguments: the cost model, a reader for already-computed table cells, the
-- character for row @j@, the character for column @i@, and the cell
-- coordinates. Reads the three neighbours @(i - 1, j)@, @(i, j - 1)@ and
-- @(i - 1, j - 1)@, so all of them must have been written beforehand.
standardCosts :: EditCosts -> ((Int, Int) -> ST s Int) -> Char -> Char -> (Int, Int) -> ST s Int
standardCosts !costs read_cost !row_char !col_char (!i, !j) = do
    -- Drop the column character: step from the cell to the left.
    deletion_cost  <- fmap (+ deletionCost costs col_char)  $ read_cost (i - 1, j)
    -- Add the row character: step from the cell above.
    insertion_cost <- fmap (+ insertionCost costs row_char) $ read_cost (i, j - 1)
    -- Diagonal step: free when the characters already match.
    subst_cost     <- fmap (+ if row_char == col_char
                                then 0
                                else substitutionCost costs col_char row_char)
                           (read_cost (i - 1, j - 1))
    return $ deletion_cost `min` insertion_cost `min` subst_cost