module GHC.Stg.InferTags.Rewrite (rewriteTopBinds)
where
import GHC.Prelude
import GHC.Types.Id
import GHC.Types.Name
import GHC.Types.Unique.Supply
import GHC.Types.Unique.FM
import GHC.Types.RepType
import GHC.Unit.Types (Module, isInteractiveModule)
import GHC.Core.DataCon
import GHC.Core (AltCon(..) )
import GHC.Core.Type
import GHC.StgToCmm.Types
import GHC.Stg.Utils
import GHC.Stg.Syntax as StgSyn
import GHC.Data.Maybe
import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
import GHC.Utils.Outputable
import GHC.Utils.Monad.State.Strict
import GHC.Utils.Misc
import GHC.Stg.InferTags.Types
import Control.Monad
import GHC.Types.Basic (CbvMark (NotMarkedCbv, MarkedCbv), isMarkedCbv, TopLevelFlag(..), isTopLevel, Levity(..))
import GHC.Types.Var.Set
-- | The rewriting monad: a strict 'State' over a 4-tuple of
--
--   1. the tag signatures of the binders currently in scope,
--   2. a unique supply used to create fresh binders,
--   3. the module being rewritten,
--   4. the set of locally (non-top-level) bound Ids currently in scope.
newtype RM a = RM { unRM :: State (UniqFM Id TagSig, UniqSupply, Module, IdSet) a }
    deriving (Functor, Applicative, Monad)
-- | Hand out one half of the current supply and keep the other half in
-- the state, so later requests never reuse uniques.
instance MonadUnique RM where
    getUniqueSupplyM = RM $ do
        (tagMap, supply, thisMod, locals) <- get
        let (handedOut, kept) = splitUniqSupply supply
        put (tagMap, kept, thisMod, locals)
        pure handedOut
-- | Read the map from binders to their tag signatures.
getMap :: RM (UniqFM Id TagSig)
getMap = RM $ do
    (sigs, _, _, _) <- get
    return sigs
-- | Replace the binder-to-tag-signature map, leaving the rest of the
-- state untouched.
setMap :: UniqFM Id TagSig -> RM ()
setMap newSigs = RM $ do
    (_, supply, thisMod, locals) <- get
    put (newSigs, supply, thisMod, locals)
-- | The module whose bindings are being rewritten.
getMod :: RM Module
getMod = RM $ do
    (_, _, thisMod, _) <- get
    return thisMod
-- | The set of locally bound Ids currently in scope.
getFVs :: RM IdSet
getFVs = RM $ do
    (_, _, _, locals) <- get
    return locals
-- | Replace the set of locally bound Ids, leaving the rest of the state
-- untouched.
setFVs :: IdSet -> RM ()
setFVs locals = RM $ do
    (sigs, supply, thisMod, _) <- get
    put (sigs, supply, thisMod, locals)
-- | Run @cont@ with the binder(s) of a binding brought into scope together
-- with their tag signatures. The right-hand sides are not looked at.
withBind :: TopLevelFlag -> GenStgBinding 'InferTaggedBinders -> RM a -> RM a
withBind top_flag bind cont = case bind of
    StgNonRec bnd _ -> withBinder top_flag bnd cont
    StgRec pairs    -> withBinders top_flag (map fst pairs) cont
-- | Add the tag signatures of a top-level binding to the signature map.
-- Unlike 'withBinder' / 'withBinders', the previous map is not restored
-- afterwards.
addTopBind :: GenStgBinding 'InferTaggedBinders -> RM ()
addTopBind (StgNonRec (bndr, tag) _) =
    getMap >>= \sigs -> setMap (addToUFM sigs bndr tag)
addTopBind (StgRec pairs) = do
    !sigs <- getMap
    setMap $! addListToUFM sigs (map fst pairs)
-- | Run @cont@ with one binder and its tag signature in scope. A
-- non-top-level binder is also added to the set of locals for the
-- duration of @cont@. The previous signature map is restored afterwards.
withBinder :: TopLevelFlag -> (Id, TagSig) -> RM a -> RM a
withBinder top_flag (bndr, sig) cont = do
    savedSigs <- getMap
    setMap (addToUFM savedSigs bndr sig)
    result <- scoped
    result <$ setMap savedSigs
  where
    scoped
        | isTopLevel top_flag = cont
        | otherwise           = withLcl bndr cont
-- | Run @cont@ with several binders and their tag signatures in scope,
-- then restore the previous signature map. For non-top-level binders the
-- set of locals is extended for @cont@ and restored afterwards too.
withBinders :: TopLevelFlag -> [(Id, TagSig)] -> RM a -> RM a
withBinders top_flag sigs cont = do
    savedSigs <- getMap
    case top_flag of
        TopLevel -> do
            setMap (addListToUFM savedSigs sigs)
            result <- cont
            result <$ setMap savedSigs
        NotTopLevel -> do
            savedLocals <- getFVs
            setMap (addListToUFM savedSigs sigs)
            setFVs (extendVarSetList savedLocals (map fst sigs))
            result <- cont
            setMap savedSigs
            setFVs savedLocals
            return result
-- | Run @act@ with the free variables of a closure added to the set of
-- locals, restoring the previous set afterwards.
withClosureLcls :: DIdSet -> RM a -> RM a
withClosureLcls fvs act = do
    saved <- getFVs
    setFVs (nonDetStrictFoldDVarSet (\v acc -> extendVarSet acc v) saved fvs)
    r <- act
    r <$ setFVs saved
-- | Run @act@ with one extra Id in the set of locals, restoring the
-- previous set afterwards.
withLcl :: Id -> RM a -> RM a
withLcl fv act = do
    saved <- getFVs
    setFVs (saved `extendVarSet` fv)
    r <- act
    r <$ setFVs saved
-- | Can an occurrence of @v@ be assumed to be tagged?
--
-- For Ids defined in (or local to) this module:
--
-- * an Id whose type is definitely unlifted is considered tagged;
-- * for every other Id we consult the inferred tag signature. A missing
--   entry is only expected in interactive modules (this is asserted) and
--   is then treated as 'TagDunno'.
--
-- For imported Ids:
--
-- * nullary data-con workers are tagged;
-- * otherwise the answer is derived from the Id's 'LambdaFormInfo' if
--   present, and is 'False' when there is none.
isTagged :: Id -> RM Bool
isTagged v = do
    this_mod <- getMod
    let lookupDefault v = assertPpr (isInteractiveModule this_mod)
                                    (text "unknown Id:" <> ppr this_mod <+> ppr v)
                                    (TagSig TagDunno)
    case nameIsLocalOrFrom this_mod (idName v) of
        True
            -- Use typeLevity_maybe rather than isUnliftedType: v may be the
            -- Id of a representation-polymorphic join point, whose levity is
            -- not known. isUnliftedType panics on such a type (GHC #22212),
            -- whereas typeLevity_maybe returns Nothing and we fall through
            -- to the signature lookup below.
            | Just Unlifted <- typeLevity_maybe (idType v)
            -> return True
            | otherwise -> do
                !s <- getMap
                let !sig = lookupWithDefaultUFM s (lookupDefault v) v
                return $ case sig of
                    TagSig info ->
                        case info of
                            TagDunno   -> False
                            TagProper  -> True
                            TagTagged  -> True
                            -- Unboxed tuples are treated as tagged.
                            TagTuple _ -> True
        False
            | Just con <- isDataConWorkId_maybe v
            , isNullaryRepDataCon con
            -> return True
            | Just lf_info <- idLFInfo_maybe v
            -> return $
                case lf_info of
                    LFReEntrant {}   -> True
                    -- Thunks are never assumed to be tagged.
                    LFThunk {}       -> False
                    LFCon {}         -> True
                    LFUnknown {}     -> False
                    LFUnlifted {}    -> True
                    LFLetNoEscape {} -> True
            | otherwise
            -> return False
-- | Literal arguments are always considered tagged; variable arguments
-- are checked with 'isTagged'.
isArgTagged :: StgArg -> RM Bool
isArgTagged arg = case arg of
    StgLitArg _ -> return True
    StgVarArg v -> isTagged v
-- | Create a fresh, localised copy of an Id (same Id, new unique). In
-- 'mkSeqs' the result is used as the case binder for the evaluated value.
mkLocalArgId :: Id -> RM Id
mkLocalArgId orig_id = do
    !fresh <- getUniqueM
    let !renamed = setIdUnique (localiseId orig_id) fresh
    return renamed
-- | Entry point: rewrite every top-level binding of the module. Wherever
-- a strict constructor field or a call-by-value argument receives a
-- variable not known to be tagged, an evaluation of that variable is
-- inserted first. Rewriting starts with an empty signature map and an
-- empty set of locals.
rewriteTopBinds :: Module -> UniqSupply -> [GenStgTopBinding 'InferTaggedBinders] -> [TgStgTopBinding]
rewriteTopBinds mod us binds =
    evalState (unRM (mapM rewriteTop binds)) initialState
  where
    initialState = (mempty, us, mod, mempty)
-- | Rewrite one top-level binding. String literals are passed through
-- unchanged. A lifted binding first has its binders' signatures recorded,
-- then its right-hand sides are rewritten.
rewriteTop :: InferStgTopBinding -> RM TgStgTopBinding
rewriteTop top = case top of
    StgTopStringLit v s -> return $! StgTopStringLit v s
    StgTopLifted bind   -> do
        addTopBind bind
        StgTopLifted <$!> rewriteBinds TopLevel bind
-- | Rewrite the right-hand side(s) of a binding and drop the tag
-- signatures from the binders. For a recursive group the group's own
-- binders are in scope while the right-hand sides are rewritten.
rewriteBinds :: TopLevelFlag -> InferStgBinding -> RM TgStgBinding
rewriteBinds _top_flag (StgNonRec bndr rhs) = do
    !rhs' <- rewriteRhs bndr rhs
    return $! StgNonRec (fst bndr) rhs'
rewriteBinds top_flag grp@(StgRec pairs) =
    withBind top_flag grp $ do
        rhss' <- mapM (uncurry rewriteRhs) pairs
        let plainBndrs = [ b | ((b, _), _) <- pairs ]
            !rec'      = StgRec (zip plainBndrs rhss')
        return rec'
-- | Rewrite a right-hand side.
--
-- A constructor RHS whose strict fields all receive tagged arguments is
-- kept as is. Otherwise it becomes a 'ReEntrant' closure with no
-- arguments. The closure first evaluates the untagged arguments and then
-- builds the constructor.
--
-- For a closure RHS, its arguments and free variables are in scope while
-- its body is rewritten.
rewriteRhs :: (Id,TagSig) -> InferStgRhs
           -> RM (TgStgRhs)
rewriteRhs _binding (StgRhsCon ccs con cn ticks args) = do
    argTagged <- mapM isArgTagged args
    let untaggedStrict :: [Id]
        untaggedStrict =
            [ v | (StgVarArg v, False) <- getStrictConArgs con (zip args argTagged) ]
    if null untaggedStrict
        then return $! StgRhsCon ccs con cn ticks args
        else do
            let ty_stub = panic "mkSeqs shouldn't use the type arg"
                rebuild taggedArgs = StgConApp con cn taggedArgs ty_stub
            evaluated <- mkSeqs args untaggedStrict rebuild
            fvs <- fvArgs args
            return $! (StgRhsClosure fvs ccs ReEntrant [] $! evaluated)
rewriteRhs _binding (StgRhsClosure fvs ccs flag bndrs body) =
    withBinders NotTopLevel bndrs $
    withClosureLcls fvs $ do
        body' <- rewriteExpr False body
        return (StgRhsClosure fvs ccs flag (map fst bndrs) body')
-- | The variable arguments that are locally bound in the current scope.
-- These are the free variables of a closure wrapping these arguments.
fvArgs :: [StgArg] -> RM DVarSet
fvArgs args = do
    locals <- getFVs
    let argVars = [ v | StgVarArg v <- args ]
    return $ mkDVarSet (filter (`elemVarSet` locals) argVars)
-- | Whether the expression being rewritten is the scrutinee of a case.
-- 'rewriteCase' passes 'True' for its scrutinee. Only 'rewriteApp'
-- changes its result based on this flag.
type IsScrut = Bool
-- | Rewrite an expression by dispatching on its form. Literals and
-- primitive operations are returned unchanged. Ticks are transparent and
-- keep the scrutinee flag.
rewriteExpr :: IsScrut -> InferStgExpr -> RM TgStgExpr
rewriteExpr isScrut expr = case expr of
    StgCase {}              -> rewriteCase expr
    StgLet {}               -> rewriteLet expr
    StgLetNoEscape {}       -> rewriteLetNoEscape expr
    StgTick t e             -> StgTick t <$!> rewriteExpr isScrut e
    StgConApp {}            -> rewriteConApp expr
    StgApp {}               -> rewriteApp isScrut expr
    StgLit lit              -> return $! StgLit lit
    StgOpApp op args res_ty -> return $! StgOpApp op args res_ty
-- | Rewrite a case expression. The case binder is in scope throughout.
-- The scrutinee is rewritten in scrutinee position, then the
-- alternatives are rewritten in order.
rewriteCase :: InferStgExpr -> RM TgStgExpr
rewriteCase (StgCase scrut bndr alt_type alts) =
    withBinder NotTopLevel bndr $ do
        scrut' <- rewriteExpr True scrut
        alts'  <- mapM rewriteAlt alts
        return (StgCase scrut' (fst bndr) alt_type alts')
rewriteCase _ = panic "Impossible: nodeCase"
-- | Rewrite a case alternative, with its pattern binders in scope, and
-- strip the tag signatures from those binders.
rewriteAlt :: InferStgAlt -> RM TgStgAlt
rewriteAlt alt =
    withBinders NotTopLevel (alt_bndrs alt) $ do
        !rhs' <- rewriteExpr False (alt_rhs alt)
        return $! alt { alt_bndrs = map fst (alt_bndrs alt), alt_rhs = rhs' }
-- | Rewrite a let: first its binding, then its body with the bound
-- Ids in scope.
rewriteLet :: InferStgExpr -> RM TgStgExpr
rewriteLet (StgLet xt bind body) = do
    !bind' <- rewriteBinds NotTopLevel bind
    !body' <- withBind NotTopLevel bind (rewriteExpr False body)
    return $! StgLet xt bind' body'
rewriteLet _ = panic "Impossible"
-- | Rewrite a let-no-escape: first its binding, then its body with the
-- bound Ids in scope.
rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr
rewriteLetNoEscape (StgLetNoEscape xt bind body) = do
    !bind' <- rewriteBinds NotTopLevel bind
    !body' <- withBind NotTopLevel bind (rewriteExpr False body)
    return $! StgLetNoEscape xt bind' body'
rewriteLetNoEscape _ = panic "Impossible"
-- | Rewrite a constructor application. Variable arguments to strict fields
-- that are not known to be tagged are evaluated (via 'mkSeqs') before the
-- constructor is applied.
rewriteConApp :: InferStgExpr -> RM TgStgExpr
rewriteConApp (StgConApp con cn args tys) = do
    argTagged <- mapM isArgTagged args
    let needEval :: [Id]
        needEval =
            [ v | (False, StgVarArg v) <- getStrictConArgs con (zip argTagged args) ]
        rebuild taggedArgs = StgConApp con cn taggedArgs tys
    if null needEval
        then return $! rebuild args
        else mkSeqs args needEval rebuild
rewriteConApp _ = panic "Impossible"
-- | Rewrite an application.
--
-- * A nullary application in scrutinee position gets a 'TagProper'
--   signature on its occurrence when 'isTagged' says it is tagged.
--   'isTagged' uses more than the inferred signature, so the occurrence
--   is updated whenever possible.
-- * An application of a function with call-by-value marks evaluates every
--   untagged variable argument in a marked position before the call.
-- * Anything else is returned unchanged.
rewriteApp :: IsScrut -> InferStgExpr -> RM TgStgExpr
rewriteApp True (StgApp f []) = do
    tagged <- isTagged f
    let f' | tagged    = setIdTagSig f (TagSig TagProper)
           | otherwise = f
    return $! StgApp f' []
rewriteApp _ (StgApp f args)
    | Just marks <- idCbvMarks_maybe f
    , relevant_marks <- dropWhileEndLE (not . isMarkedCbv) marks
    , any isMarkedCbv relevant_marks
    = assert (length relevant_marks <= length args) $
      forceCbvArgs relevant_marks
  where
    -- Positions beyond the relevant marks are not call-by-value.
    forceCbvArgs relevant_marks = do
        argTags <- mapM isArgTagged args
        let paddedMarks = relevant_marks ++ repeat NotMarkedCbv
            toForce :: [Id]
            toForce = [ x | (StgVarArg x, MarkedCbv, False) <- zip3 args paddedMarks argTags ]
        mkSeqs args toForce (StgApp f)
rewriteApp _ (StgApp f args) = return $ StgApp f args
rewriteApp _ _ = panic "Impossible"
-- | @mkSeq v bndr body@ builds @case v of bndr { DEFAULT -> body }@: it
-- evaluates @v@, binds the result to @bndr@ and continues with @body@.
mkSeq :: Id -> Id -> TgStgExpr -> TgStgExpr
mkSeq scrut_id bndr !expr =
    StgCase (StgApp scrut_id []) bndr (mkStgAltTypeFromStgAlts bndr alts) alts
  where
    alts = [GenStgAlt { alt_con = DEFAULT, alt_bndrs = [], alt_rhs = expr }]
-- | Evaluate each of the given Ids (in list order, outermost first) and
-- bind the result to a fresh local copy. Then rebuild the expression with
-- @mkExpr@, applied to the original arguments with each evaluated Id
-- replaced by its fresh copy.
mkSeqs :: [StgArg]                  -- ^ original arguments
       -> [Id]                      -- ^ variables to evaluate first
       -> ([StgArg] -> TgStgExpr)   -- ^ rebuilds the expression from the
                                    --   substituted arguments
       -> RM TgStgExpr
mkSeqs args untaggedIds mkExpr = do
    renaming <- forM untaggedIds $ \v -> do
        v' <- mkLocalArgId v
        return (v, v')
    let subst (StgVarArg v) = StgVarArg (fromMaybe v (lookup v renaming))
        subst lit           = lit
        !result = foldr (uncurry mkSeq) (mkExpr (map subst args)) renaming
    return result
-- | Keep the elements of @args@ (one per runtime argument of @con@) that
-- correspond to strict fields. Unboxed tuple and unboxed sum constructors
-- have no strict fields, so the result is empty for them.
getStrictConArgs :: DataCon -> [a] -> [a]
getStrictConArgs con args
    | isUnboxedTupleDataCon con || isUnboxedSumDataCon con = []
    | otherwise =
        [ a | (a, MarkedStrict) <- zipEqual "getStrictConArgs" args repStrictness ]
  where
    repStrictness = dataConRuntimeRepStrictness con