module GHC.Tc.Gen.Match
( tcMatchesFun
, tcGRHS
, tcGRHSsPat
, tcMatchesCase
, tcMatchLambda
, TcMatchCtxt(..)
, TcStmtChecker
, TcExprStmtChecker
, TcCmdStmtChecker
, tcStmts
, tcStmtsAndThen
, tcDoStmts
, tcBody
, tcDoStmt
, tcGuardStmt
, checkArgCounts
)
where
import GHC.Prelude
import GHC.Tc.Gen.Expr( tcSyntaxOp, tcInferRho, tcInferRhoNC
, tcMonoExpr, tcMonoExprNC, tcExpr
, tcCheckMonoExpr, tcCheckMonoExprNC
, tcCheckPolyExpr )
import GHC.Tc.Errors.Types
import GHC.Tc.Utils.Monad
import GHC.Tc.Utils.Env
import GHC.Tc.Gen.Pat
import GHC.Tc.Gen.Head( tcCheckId )
import GHC.Tc.Utils.TcMType
import GHC.Tc.Utils.TcType
import GHC.Tc.Gen.Bind
import GHC.Tc.Utils.Concrete ( hasFixedRuntimeRep_syntactic )
import GHC.Tc.Utils.Unify
import GHC.Tc.Types.Origin
import GHC.Tc.Types.Evidence
import GHC.Core.Multiplicity
import GHC.Core.UsageEnv
import GHC.Core.TyCon
import GHC.Core.Make
import GHC.Hs
import GHC.Builtin.Types
import GHC.Builtin.Types.Prim
import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Utils.Misc
import GHC.Driver.Session ( getDynFlags )
import GHC.Types.Error
import GHC.Types.Fixity (LexicalFixity(..))
import GHC.Types.Name
import GHC.Types.Id
import GHC.Types.SrcLoc
import Control.Monad
import Control.Arrow ( second )
-- | Typecheck the equations of a function binding against the expected
-- type @exp_ty@.
--
-- All equations must take the same number of arguments ('checkArgCounts').
-- 'matchExpectedFunTys' splits the expected type into that many argument
-- types and a result type. The equations are then checked at multiplicity
-- 'Many'.
--
-- The 'FunRhs' context records the function name and prefix fixity. It
-- records 'SrcStrict' only when the group is a single equation whose own
-- context is marked 'SrcStrict'.
tcMatchesFun :: LocatedN Name
             -> MatchGroup GhcRn (LHsExpr GhcRn)
             -> ExpRhoType
             -> TcM (HsWrapper, MatchGroup GhcTc (LHsExpr GhcTc))
tcMatchesFun fun_name matches exp_ty = do
  traceTc "tcMatchesFun" (ppr fun_name $$ ppr exp_ty)
  checkArgCounts fun_ctxt matches
  matchExpectedFunTys fun_herald GenSigCtxt (matchGroupArity matches) exp_ty $
    \arg_tys body_ty ->
      tcScalingUsage Many (tcMatches eqn_ctxt arg_tys body_ty matches)
  where
    fun_herald = ExpectedFunTyMatches (NameThing (unLoc fun_name)) matches
    fun_ctxt   = FunRhs { mc_fun        = fun_name
                        , mc_fixity     = Prefix
                        , mc_strictness = bang }
    eqn_ctxt   = MC { mc_what = fun_ctxt, mc_body = tcBody }
    -- Only a lone equation passes its strictness mark on to the group.
    bang = case unLoc (mg_alts matches) of
      [L _ (Match { m_ctxt = FunRhs { mc_strictness = SrcStrict } })] -> SrcStrict
      _                                                               -> NoSrcStrict
-- | Typecheck the alternatives of a case-like match group.
--
-- The scrutinee's type is already known. It is passed to 'tcMatches' as the
-- single argument type, in checking mode, keeping the scrutinee's
-- multiplicity.
tcMatchesCase :: (AnnoBody body)
              => TcMatchCtxt body
              -> Scaled TcSigmaTypeFRR
              -> MatchGroup GhcRn (LocatedA (body GhcRn))
              -> ExpRhoType
              -> TcM (MatchGroup GhcTc (LocatedA (body GhcTc)))
tcMatchesCase ctxt (Scaled scrut_mult scrut_ty) matches res_ty
  = tcMatches ctxt [scrut_arg] res_ty matches
  where
    scrut_arg = Scaled scrut_mult (mkCheckExpType scrut_ty)
-- | Typecheck the match group of a lambda against the expected type @res_ty@.
--
-- The argument counts of the alternatives are checked first. The expected
-- type is then split into one argument type per pattern. An empty match
-- group is still given one argument type (presumably an empty lambda-case;
-- TODO confirm).
tcMatchLambda :: ExpectedFunTyOrigin
              -> TcMatchCtxt HsExpr
              -> MatchGroup GhcRn (LHsExpr GhcRn)
              -> ExpRhoType
              -> TcM (HsWrapper, MatchGroup GhcTc (LHsExpr GhcTc))
tcMatchLambda herald match_ctxt match res_ty = do
  checkArgCounts (mc_what match_ctxt) match
  matchExpectedFunTys herald GenSigCtxt arity res_ty $ \arg_tys body_ty ->
    tcMatches match_ctxt arg_tys body_ty match
  where
    arity = if isEmptyMatchGroup match then 1 else matchGroupArity match
-- | Typecheck the guarded right-hand sides of a pattern binding.
--
-- The bodies are checked with 'tcBody' in a 'PatBindRhs' context, and the
-- whole check runs at multiplicity 'Many'.
tcGRHSsPat :: GRHSs GhcRn (LHsExpr GhcRn) -> ExpRhoType
           -> TcM (GRHSs GhcTc (LHsExpr GhcTc))
tcGRHSsPat grhss res_ty =
  tcScalingUsage Many (tcGRHSs pat_rhs_ctxt grhss res_ty)
  where
    pat_rhs_ctxt :: TcMatchCtxt HsExpr
    pat_rhs_ctxt = MC { mc_what = PatBindRhs, mc_body = tcBody }
-- | Context used when typechecking a group of matches.
--
-- * 'mc_what' is the kind of match being checked. It is:
--   - stored in each typechecked 'Match';
--   - passed to 'tcPats';
--   - used to build the 'PatGuard' context for guards;
--   - used to decide whether an error context is added for each match
--     (none is added for 'LambdaExpr').
-- * 'mc_body' typechecks one right-hand-side body against the expected
--   result type.
data TcMatchCtxt body
  = MC { mc_what :: HsMatchContext GhcTc,
         mc_body :: LocatedA (body GhcRn)
                 -> ExpRhoType
                 -> TcM (LocatedA (body GhcTc)) }
-- | Constraints needed by the match and statement checkers for a
-- right-hand-side body type constructor @body@ (e.g. 'HsExpr').
--
-- The renamed body must be 'Outputable'. The constraint also pins the
-- source-annotation type ('Anno') of every node built around a body, for
-- both the renamed ('GhcRn') and typechecked ('GhcTc') passes. That covers
-- matches, lists of matches, guarded right-hand sides and statements.
type AnnoBody body
  = ( Outputable (body GhcRn)
    , Anno (Match GhcRn (LocatedA (body GhcRn))) ~ SrcSpanAnnA
    , Anno (Match GhcTc (LocatedA (body GhcTc))) ~ SrcSpanAnnA
    , Anno [LocatedA (Match GhcRn (LocatedA (body GhcRn)))] ~ SrcSpanAnnL
    , Anno [LocatedA (Match GhcTc (LocatedA (body GhcTc)))] ~ SrcSpanAnnL
    , Anno (GRHS GhcRn (LocatedA (body GhcRn))) ~ SrcAnn NoEpAnns
    , Anno (GRHS GhcTc (LocatedA (body GhcTc))) ~ SrcAnn NoEpAnns
    , Anno (StmtLR GhcRn GhcRn (LocatedA (body GhcRn))) ~ SrcSpanAnnA
    , Anno (StmtLR GhcTc GhcTc (LocatedA (body GhcTc))) ~ SrcSpanAnnA
    )
-- | Typecheck every alternative of a match group against the given argument
-- types and result type.
--
-- Non-empty group: the usage of each alternative is collected separately
-- and the combination ('supUEs') is emitted. The expected types are then
-- read back.
--
-- Empty group: 'bottomUE' is emitted, and the expected types are turned
-- into types with 'scaledExpTypeToType' / 'expTypeToType'. Presumably this
-- is because no alternative exists to fill in an inference-mode expected
-- type.
tcMatches :: (AnnoBody body ) => TcMatchCtxt body
          -> [Scaled ExpSigmaTypeFRR]
          -> ExpRhoType
          -> MatchGroup GhcRn (LocatedA (body GhcRn))
          -> TcM (MatchGroup GhcTc (LocatedA (body GhcTc)))
tcMatches ctxt pat_tys rhs_ty (MG { mg_alts = L l matches
                                  , mg_origin = origin })
  = do { (matches', arg_tys, res_ty)
           <- if null matches then tc_no_alts else tc_alts
       ; return (MG { mg_alts   = L l matches'
                    , mg_ext    = MatchGroupTc arg_tys res_ty
                    , mg_origin = origin }) }
  where
    tc_no_alts
      = do { tcEmitBindingUsage bottomUE
           ; arg_tys <- mapM scaledExpTypeToType pat_tys
           ; res_ty  <- expTypeToType rhs_ty
           ; return ([], arg_tys, res_ty) }

    tc_alts
      = do { (usages, matches') <- unzip <$> mapM tc_alt matches
           ; tcEmitBindingUsage (supUEs usages)
           ; arg_tys <- mapM readScaledExpType pat_tys
           ; res_ty  <- readExpType rhs_ty
           ; return (matches', arg_tys, res_ty) }

    tc_alt alt = tcCollectingUsage (tcMatch ctxt pat_tys rhs_ty alt)
-- | Typecheck a single alternative.
--
-- Its patterns are checked against the argument types. Its guarded
-- right-hand sides are then checked, inside the patterns' scope, against
-- the result type.
--
-- An error context showing the match is added, except for 'LambdaExpr'
-- matches. Presumably the enclosing lambda expression already reports its
-- own context.
tcMatch :: (AnnoBody body) => TcMatchCtxt body
        -> [Scaled ExpSigmaType]
        -> ExpRhoType
        -> LMatch GhcRn (LocatedA (body GhcRn))
        -> TcM (LMatch GhcTc (LocatedA (body GhcTc)))
tcMatch ctxt pat_tys rhs_ty = wrapLocMA check_alt
  where
    what = mc_what ctxt

    check_alt alt@(Match { m_pats = pats, m_grhss = grhss })
      = in_match_ctxt alt $
        do { (pats', grhss') <- tcPats what pats pat_tys $
                                tcGRHSs ctxt grhss rhs_ty
           ; return (Match { m_ext   = noAnn
                           , m_ctxt  = what
                           , m_pats  = pats'
                           , m_grhss = grhss' }) }

    in_match_ctxt alt thing_inside
      = case what of
          LambdaExpr -> thing_inside
          _          -> addErrCtxt (pprMatchInCtxt alt) thing_inside
-- | Typecheck the guarded right-hand sides of one alternative, together with
-- its local bindings. The bindings scope over all of the right-hand sides.
--
-- The usage of each guarded right-hand side is collected separately, and
-- the combination ('supUEs') is emitted.
tcGRHSs :: AnnoBody body
        => TcMatchCtxt body -> GRHSs GhcRn (LocatedA (body GhcRn)) -> ExpRhoType
        -> TcM (GRHSs GhcTc (LocatedA (body GhcTc)))
tcGRHSs ctxt (GRHSs _ grhss binds) res_ty = do
  (binds', tagged) <- tcLocalBinds binds (mapM tc_rhs grhss)
  let (usages, grhss') = unzip tagged
  tcEmitBindingUsage (supUEs usages)
  return (GRHSs emptyComments grhss' binds')
  where
    tc_rhs grhs = tcCollectingUsage (wrapLocMA (tcGRHS ctxt res_ty) grhs)
-- | Typecheck one guarded right-hand side.
--
-- The guards are checked in order as a 'PatGuard' statement sequence. The
-- body is then checked with 'mc_body', inside the scope of everything the
-- guards bind.
tcGRHS :: TcMatchCtxt body -> ExpRhoType -> GRHS GhcRn (LocatedA (body GhcRn))
       -> TcM (GRHS GhcTc (LocatedA (body GhcTc)))
tcGRHS ctxt res_ty (GRHS _ guards rhs) = do
  (guards', rhs') <-
    tcStmtsAndThen (PatGuard (mc_what ctxt)) tcGuardStmt guards res_ty
                   (mc_body ctxt rhs)
  return (GRHS noAnn guards' rhs')
-- | Typecheck the statements of a @do@ block or comprehension. The
-- statement checker is chosen from the block's flavour.
--
-- * 'ListComp': the expected type is turned into a type and matched against
--   a list type. The statements are checked with 'tcLcStmt' for 'listTyCon'
--   at the element type. The result is wrapped in the coercion from that
--   match.
-- * @do@ / @mdo@ ('tcDoStmt') and 'MonadComp' ('tcMcStmt'): the statements
--   are checked against the expected type. That type is then read back to
--   annotate the 'HsDo'.
-- * 'GhciStmtCtxt' is not expected here and panics.
tcDoStmts :: HsDoFlavour
          -> LocatedL [LStmt GhcRn (LHsExpr GhcRn)]
          -> ExpRhoType
          -> TcM (HsExpr GhcTc)
tcDoStmts flavour lstmts res_ty
  = case flavour of
      ListComp     -> tc_list_comp
      DoExpr _     -> tc_monadic tcDoStmt
      MDoExpr _    -> tc_monadic tcDoStmt
      MonadComp    -> tc_monadic tcMcStmt
      GhciStmtCtxt -> pprPanic "tcDoStmts" (pprHsDoFlavour flavour)
  where
    -- Lazy pattern binding: the located statements are only taken apart by
    -- the branches that use them.
    L l stmts = lstmts

    tc_list_comp
      = do { overall_ty   <- expTypeToType res_ty
           ; (co, elt_ty) <- matchExpectedListTy overall_ty
           ; stmts'       <- tcStmts (HsDoStmt ListComp) (tcLcStmt listTyCon)
                                     stmts (mkCheckExpType elt_ty)
           ; return $ mkHsWrapCo co (HsDo (mkListTy elt_ty) ListComp (L l stmts')) }

    tc_monadic :: TcExprStmtChecker -> TcM (HsExpr GhcTc)
    tc_monadic check_stmt
      = do { stmts'   <- tcStmts (HsDoStmt flavour) check_stmt stmts res_ty
           ; final_ty <- readExpType res_ty
           ; return (HsDo final_ty flavour (L l stmts')) }
-- | Default body checker. It traces the expected type, then checks the body
-- expression against it with 'tcMonoExpr'.
tcBody :: LHsExpr GhcRn -> ExpRhoType -> TcM (LHsExpr GhcTc)
tcBody body res_ty =
  traceTc "tcBody" (ppr res_ty) >> tcMonoExpr body res_ty
-- | Statement checker for statements whose bodies are expressions. The
-- result type threaded through the sequence is an 'ExpRhoType'.
type TcExprStmtChecker = TcStmtChecker HsExpr ExpRhoType

-- | Statement checker for statements whose bodies are commands ('HsCmd').
-- The result type threaded through the sequence is a plain 'TcRhoType'.
type TcCmdStmtChecker = TcStmtChecker HsCmd TcRhoType

-- | A checker for one statement of a sequence.
--
-- Inputs:
-- * the statement context;
-- * the renamed statement;
-- * the result type of the enclosing sequence;
-- * a continuation.
--
-- The checker typechecks the statement and calls the continuation with the
-- result type the /rest/ of the sequence should have. ('tcStmtsAndThen'
-- passes a continuation that checks the remaining statements.) It returns
-- the typechecked statement paired with the continuation's result. The
-- checker must be polymorphic in that result type (@thing@).
type TcStmtChecker body rho_type
  = forall thing. HsStmtContext GhcTc
                -> Stmt GhcRn (LocatedA (body GhcRn))
                -> rho_type
                -> (rho_type -> TcM thing)
                -> TcM (Stmt GhcTc (LocatedA (body GhcTc)), thing)
-- | Typecheck a sequence of statements with the given per-statement checker
-- and return just the typechecked statements. The final continuation does
-- nothing.
tcStmts :: (AnnoBody body) => HsStmtContext GhcTc
        -> TcStmtChecker body rho_type
        -> [LStmt GhcRn (LocatedA (body GhcRn))]
        -> rho_type
        -> TcM [LStmt GhcTc (LocatedA (body GhcTc))]
tcStmts ctxt stmt_chk stmts res_ty =
  fst <$> tcStmtsAndThen ctxt stmt_chk stmts res_ty (\_ -> return ())
-- | Typecheck a sequence of statements with the per-statement checker
-- @stmt_chk@. Then run @thing_inside@ on the result type in effect after
-- the last statement, with everything the statements bind in scope.
--
-- Returns the typechecked statements and @thing_inside@'s result.
tcStmtsAndThen :: (AnnoBody body) => HsStmtContext GhcTc
               -> TcStmtChecker body rho_type
               -> [LStmt GhcRn (LocatedA (body GhcRn))]
               -> rho_type
               -> (rho_type -> TcM thing)
               -> TcM ([LStmt GhcTc (LocatedA (body GhcTc))], thing)

-- No statements left: just run the continuation.
tcStmtsAndThen _ _ [] res_ty thing_inside
  = do  { thing <- thing_inside res_ty
        ; return ([], thing) }

-- 'LetStmt's are handled here, in every context, without consulting
-- @stmt_chk@. The bindings are typechecked and scope over the remaining
-- statements.
tcStmtsAndThen ctxt stmt_chk (L loc (LetStmt x binds) : stmts)
                                                             res_ty thing_inside
  = do  { (binds', (stmts',thing)) <- tcLocalBinds binds $
              tcStmtsAndThen ctxt stmt_chk stmts res_ty thing_inside
        ; return (L loc (LetStmt x binds') : stmts', thing) }

tcStmtsAndThen ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside
  -- 'ApplicativeStmt': no source span and no statement error context is set
  -- here. tcDoStmt's handler ('tcApplicativeStmts') sets a span and context
  -- for each 'ApplicativeArgOne' instead.
  | ApplicativeStmt{} <- stmt
  = do  { (stmt', (stmts', thing)) <-
             stmt_chk ctxt stmt res_ty $ \ res_ty' ->
               tcStmtsAndThen ctxt stmt_chk stmts res_ty'  $
                 thing_inside
        ; return (L loc stmt' : stmts', thing) }

  -- Any other statement: set its source span and add an error context
  -- naming it while @stmt_chk@ runs. The continuation that checks the
  -- remaining statements is wrapped in 'popErrCtxt'. Presumably this drops
  -- the context again, so errors in later statements are not reported as
  -- being inside this one; TODO confirm.
  | otherwise
  = do  { (stmt', (stmts', thing)) <-
                setSrcSpanA loc                             $
                addErrCtxt (pprStmtInCtxt ctxt stmt)        $
                stmt_chk ctxt stmt res_ty                   $ \ res_ty' ->
                popErrCtxt                                  $
                tcStmtsAndThen ctxt stmt_chk stmts res_ty'  $
                thing_inside
        ; return (L loc stmt' : stmts', thing) }
-- | Statement checker for guards.
--
-- * A boolean guard is checked against 'Bool'.
-- * A pattern guard @pat <- rhs@:
--   - infers the type of @rhs@;
--   - requires that type to have a fixed runtime representation;
--   - checks @pat@ against it, unrestricted, with the rest of the guards
--     (and the body) in the pattern's scope.
--
-- The guard expression and the pattern-guard right-hand side are both
-- checked at multiplicity 'Many'. The result type is passed through
-- unchanged. Any other statement panics.
tcGuardStmt :: TcExprStmtChecker
tcGuardStmt _ (BodyStmt _ cond _ _) res_ty thing_inside = do
  cond' <- tcScalingUsage Many (tcCheckMonoExpr cond boolTy)
  inner <- thing_inside res_ty
  return (BodyStmt boolTy cond' noSyntaxExpr noSyntaxExpr, inner)
tcGuardStmt ctxt (BindStmt _ pat scrut) res_ty thing_inside = do
  (scrut', scrut_ty) <- tcScalingUsage Many (tcInferRhoNC scrut)
  hasFixedRuntimeRep_syntactic FRRBindStmtGuard scrut_ty
  (pat', inner) <-
    tcCheckPat_O (StmtCtxt ctxt) (lexprCtOrigin scrut) pat
                 (unrestricted scrut_ty) (thing_inside res_ty)
  return (mkTcBindStmt pat' scrut', inner)
tcGuardStmt _ other _ _ =
  pprPanic "tcGuardStmt: unexpected Stmt" (ppr other)
-- | Statement checker for list comprehensions.
--
-- The 'TyCon' argument is the type constructor generators draw from
-- ('tcDoStmts' passes 'listTyCon'): a generator's right-hand side must have
-- type @m_tc pat_ty@. The result type threaded through the statements is
-- the element type of the comprehension, passed along unchanged. No
-- rebindable operators are checked, so every operator field of the output
-- is 'noSyntaxExpr' or 'noExpr'.
tcLcStmt :: TyCon
         -> TcExprStmtChecker

-- The final expression: checked against the element type. Nothing follows
-- it, so the continuation's argument is a panic that must never be forced.
tcLcStmt _ _ (LastStmt x body noret _) elt_ty thing_inside
  = do { body' <- tcMonoExprNC body elt_ty
       ; thing <- thing_inside (panic "tcLcStmt: thing_inside")
       ; return (LastStmt x body' noret noSyntaxExpr, thing) }

-- A generator, pat <- rhs: rhs has type (m_tc pat_ty) for a fresh pat_ty.
-- The pattern is checked, unrestricted, at pat_ty, with the rest of the
-- comprehension inside its scope.
tcLcStmt m_tc ctxt (BindStmt _ pat rhs) elt_ty thing_inside
  = do  { pat_ty <- newFlexiTyVarTy liftedTypeKind
        ; rhs'   <- tcCheckMonoExpr rhs (mkTyConApp m_tc [pat_ty])
        ; (pat', thing)  <- tcCheckPat (StmtCtxt ctxt) pat (unrestricted pat_ty) $
                            thing_inside elt_ty
        ; return (mkTcBindStmt pat' rhs', thing) }

-- A guard: a Bool-typed expression.
tcLcStmt _ _ (BodyStmt _ rhs _ _) elt_ty thing_inside
  = do  { rhs'  <- tcCheckMonoExpr rhs boolTy
        ; thing <- thing_inside elt_ty
        ; return (BodyStmt boolTy rhs' noSyntaxExpr noSyntaxExpr, thing) }

-- Parallel comprehension: 'loop' checks the blocks one after another. Each
-- later block, and finally 'thing_inside', is checked inside the
-- continuation of the previous block. After each block, its binders
-- ('names') are looked up to obtain their typechecked Ids.
tcLcStmt m_tc ctxt (ParStmt _ bndr_stmts_s _ _) elt_ty thing_inside
  = do  { (pairs', thing) <- loop bndr_stmts_s
        ; return (ParStmt unitTy pairs' noExpr noSyntaxExpr, thing) }
  where
    loop [] = do { thing <- thing_inside elt_ty
                 ; return ([], thing) }

    loop (ParStmtBlock x stmts names _ : pairs)
      = do { (stmts', (ids, pairs', thing))
                <- tcStmtsAndThen ctxt (tcLcStmt m_tc) stmts elt_ty $ \ _elt_ty' ->
                   do { ids <- tcLookupLocalIds names
                      ; (pairs', thing) <- loop pairs
                      ; return (ids, pairs', thing) }
           ; return ( ParStmtBlock x stmts' ids noSyntaxExpr : pairs', thing ) }

-- Transform / group statement.
--
-- The inner statements are checked first. Their result type is the
-- panicking 'unused_ty': the continuation ignores its argument, and
-- presumably the inner statements contain no LastStmt that would force it.
-- Inside that continuation, the optional 'by' expression is inferred and
-- the binders are looked up.
--
-- The 'using' function is checked against
--     forall a. [(a -> e_ty) ->] m_tc a -> m_tc (n a)
-- where:
-- * e_ty is the inferred type of 'by';
-- * the bracketed argument is present only with a 'by' expression;
-- * n is the identity for 'ThenForm' and m_tc otherwise.
-- The checked 'using' is then applied to the binder tuple type.
--
-- Each old binder of type t is rebound, under its new name, at type (n t)
-- for 'thing_inside'.
tcLcStmt m_tc ctxt (TransStmt { trS_form = form, trS_stmts = stmts
                              , trS_bndrs = bindersMap
                              , trS_by = by, trS_using = using }) elt_ty thing_inside
  = do { let (bndr_names, n_bndr_names) = unzip bindersMap
             unused_ty = pprPanic "tcLcStmt: inner ty" (ppr bindersMap)
       ; (stmts', (bndr_ids, by'))
            <- tcStmtsAndThen (TransStmtCtxt ctxt) (tcLcStmt m_tc) stmts unused_ty $ \_ -> do
               { by' <- traverse tcInferRho by
               ; bndr_ids <- tcLookupLocalIds bndr_names
               ; return (bndr_ids, by') }
       ; let m_app ty = mkTyConApp m_tc [ty]
       -- n_app: identity for ThenForm, m_tc application otherwise.
       ; let n_app = case form of
                       ThenForm -> (\ty -> ty)
                       _        -> m_app
             -- Prepends (a -> e_ty) when a 'by' expression is present.
             by_arrow :: Type -> Type
             by_arrow = case by' of
                          Nothing       -> \ty -> ty
                          Just (_,e_ty) -> \ty -> (alphaTy `mkVisFunTyMany` e_ty) `mkVisFunTyMany` ty
             tup_ty        = mkBigCoreVarTupTy bndr_ids
             poly_arg_ty   = m_app alphaTy
             poly_res_ty   = m_app (n_app alphaTy)
             using_poly_ty = mkInfForAllTy alphaTyVar $
                             by_arrow $
                             poly_arg_ty `mkVisFunTyMany` poly_res_ty
       ; using' <- tcCheckPolyExpr using using_poly_ty
       ; let final_using = fmap (mkHsWrap (WpTyApp tup_ty)) using'
       -- Pair each old binder with a new binder at type (n_app t).
       ; let mk_n_bndr :: Name -> TcId -> TcId
             mk_n_bndr n_bndr_name bndr_id = mkLocalId n_bndr_name Many (n_app (idType bndr_id))
             n_bndr_ids  = zipWith mk_n_bndr n_bndr_names bndr_ids
             bindersMap' = bndr_ids `zip` n_bndr_ids
       ; thing <- tcExtendIdEnv n_bndr_ids (thing_inside elt_ty)
       ; return (TransStmt { trS_stmts = stmts', trS_bndrs = bindersMap'
                           , trS_by = fmap fst by', trS_using = final_using
                           , trS_ret = noSyntaxExpr
                           , trS_bind = noSyntaxExpr
                           , trS_fmap = noExpr
                           , trS_ext = unitTy
                           , trS_form = form }, thing) }

tcLcStmt _ _ stmt _ _
  = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)
-- | Statement checker for monad comprehensions.
--
-- Unlike 'tcLcStmt', the monadic plumbing comes from syntax operators
-- stored in each statement: its return, bind, then, guard, mzip and fmap
-- operators. These are typechecked here with 'tcSyntaxOp' or
-- 'tcCheckPolyExpr'. The result type threaded through the statements is the
-- type of the rest of the comprehension.
tcMcStmt :: TcExprStmtChecker

-- The final expression: checked as the argument of the return operator,
-- scaled by the multiplicity 'tcSyntaxOp' reports for that argument.
-- Nothing follows it, so the continuation's argument is a panic that must
-- never be forced.
tcMcStmt _ (LastStmt x body noret return_op) res_ty thing_inside
  = do  { (body', return_op')
            <- tcSyntaxOp MCompOrigin return_op [SynRho] res_ty $
               \ [a_ty] [mult]->
               tcScalingUsage mult $ tcCheckMonoExprNC body a_ty
        ; thing      <- thing_inside (panic "tcMcStmt: thing_inside")
        ; return (LastStmt x body' noret return_op', thing) }

-- A generator, pat <- rhs. The bind operator is checked at
--     rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
-- Then:
-- * rhs is checked at rhs_ty and the pattern at pat_ty, each scaled by the
--   multiplicities 'tcSyntaxOp' reports;
-- * the rest of the comprehension is checked at new_res_ty, inside the
--   pattern's scope;
-- * rhs_ty must have a fixed runtime representation;
-- * a fail operator, if present, is checked only when the pattern is
--   refutable (see 'tcMonadFailOp').
tcMcStmt ctxt (BindStmt xbsrn pat rhs) res_ty thing_inside
  = do  { ((rhs_ty, rhs', pat_mult, pat', thing, new_res_ty), bind_op')
            <- tcSyntaxOp MCompOrigin (xbsrn_bindOp xbsrn)
                          [SynRho, SynFun SynAny SynRho] res_ty $
               \ [rhs_ty, pat_ty, new_res_ty] [rhs_mult, fun_mult, pat_mult] ->
               do { rhs' <- tcScalingUsage rhs_mult $ tcCheckMonoExprNC rhs rhs_ty
                  ; (pat', thing) <- tcScalingUsage fun_mult $ tcCheckPat (StmtCtxt ctxt) pat (Scaled pat_mult pat_ty) $
                                     thing_inside (mkCheckExpType new_res_ty)
                  ; return (rhs_ty, rhs', pat_mult, pat', thing, new_res_ty) }
        ; hasFixedRuntimeRep_syntactic (FRRBindStmt MonadComprehension) rhs_ty
        ; fail_op' <- fmap join . forM (xbsrn_failOp xbsrn) $ \fail ->
            tcMonadFailOp (MCompPatOrigin pat) pat' fail new_res_ty
        ; let xbstc = XBindStmtTc
                { xbstc_bindOp = bind_op'
                , xbstc_boundResultType = new_res_ty
                , xbstc_boundResultMult = pat_mult
                , xbstc_failOp = fail_op'
                }
        ; return (BindStmt xbstc pat' rhs', thing) }

-- An expression statement. The then operator is checked at
--     rhs_ty -> new_res_ty -> res_ty
-- Inside it, the guard operator is checked at test_ty -> rhs_ty, and the
-- expression itself at test_ty. test_ty, rhs_ty and new_res_ty must all
-- have a fixed runtime representation.
tcMcStmt _ (BodyStmt _ rhs then_op guard_op) res_ty thing_inside
  = do  {
        ; ((thing, rhs', rhs_ty, new_res_ty, test_ty, guard_op'), then_op')
            <- tcSyntaxOp MCompOrigin then_op [SynRho, SynRho] res_ty $
               \ [rhs_ty, new_res_ty] [rhs_mult, fun_mult] ->
               do { ((rhs', test_ty), guard_op')
                      <- tcScalingUsage rhs_mult $
                         tcSyntaxOp MCompOrigin guard_op [SynAny]
                                    (mkCheckExpType rhs_ty) $
                         \ [test_ty] [test_mult] -> do
                           rhs' <- tcScalingUsage test_mult $ tcCheckMonoExpr rhs test_ty
                           return $ (rhs', test_ty)
                  ; thing <- tcScalingUsage fun_mult $ thing_inside (mkCheckExpType new_res_ty)
                  ; return (thing, rhs', rhs_ty, new_res_ty, test_ty, guard_op') }
        ; hasFixedRuntimeRep_syntactic FRRBodyStmtGuard test_ty
        ; hasFixedRuntimeRep_syntactic (FRRBodyStmt MonadComprehension 1) rhs_ty
        ; hasFixedRuntimeRep_syntactic (FRRBodyStmt MonadComprehension 2) new_res_ty
        ; return (BodyStmt rhs_ty rhs' then_op' guard_op', thing) }

-- Transform / group statement. m1 and m2 are fresh. n is fresh too, except
-- for 'ThenForm', where it is the identity.
--
-- The 'using' function is checked against
--     forall a. [(a -> by_e_ty) ->] m1 a -> m2 (n a)
-- (the bracketed argument is present only with a 'by' expression) and
-- applied to the binder tuple type tup_ty.
--
-- The inner statements are checked at (m1 tup_ty). At their end:
-- * the 'by' expression is checked at by_e_ty;
-- * the return operator is checked on the tuple of the old binders.
--
-- The bind operator is checked at
--     m2 (n tup_ty) -> (n tup_ty -> new_res_ty) -> res_ty
-- and, except for 'ThenForm', the fmap operator at
--     forall a b. (a -> b) -> n a -> n b
-- Each old binder of type t is rebound, under its new name, at type (n t)
-- for the rest of the comprehension.
tcMcStmt ctxt (TransStmt { trS_stmts = stmts, trS_bndrs = bindersMap
                         , trS_by = by, trS_using = using, trS_form = form
                         , trS_ret = return_op, trS_bind = bind_op
                         , trS_fmap = fmap_op }) res_ty thing_inside
  = do { m1_ty   <- newFlexiTyVarTy typeToTypeKind
       ; m2_ty   <- newFlexiTyVarTy typeToTypeKind
       ; tup_ty  <- newFlexiTyVarTy liftedTypeKind
       ; by_e_ty <- newFlexiTyVarTy liftedTypeKind
       ; n_app <- case form of
                    ThenForm -> return (\ty -> ty)
                    _        -> do { n_ty <- newFlexiTyVarTy typeToTypeKind
                                   ; return (n_ty `mkAppTy`) }
       ; let by_arrow :: Type -> Type
             by_arrow = case by of
                          Nothing -> \res -> res
                          Just {} -> \res -> (alphaTy `mkVisFunTyMany` by_e_ty) `mkVisFunTyMany` res
             poly_arg_ty  = m1_ty `mkAppTy` alphaTy
             using_arg_ty = m1_ty `mkAppTy` tup_ty
             poly_res_ty  = m2_ty `mkAppTy` n_app alphaTy
             using_res_ty = m2_ty `mkAppTy` n_app tup_ty
             using_poly_ty = mkInfForAllTy alphaTyVar $
                             by_arrow $
                             poly_arg_ty `mkVisFunTyMany` poly_res_ty
       ; let (bndr_names, n_bndr_names) = unzip bindersMap
       ; (stmts', (bndr_ids, by', return_op')) <-
            tcStmtsAndThen (TransStmtCtxt ctxt) tcMcStmt stmts
                           (mkCheckExpType using_arg_ty) $ \res_ty' -> do
                { by' <- case by of
                           Nothing -> return Nothing
                           Just e  -> do { e' <- tcCheckMonoExpr e by_e_ty
                                         ; return (Just e') }
                ; bndr_ids <- tcLookupLocalIds bndr_names
                ; (_, return_op') <- tcSyntaxOp MCompOrigin return_op
                                       [synKnownType (mkBigCoreVarTupTy bndr_ids)]
                                       res_ty' $ \ _ _ -> return ()
                ; return (bndr_ids, by', return_op') }
       ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
       ; (_, bind_op') <- tcSyntaxOp MCompOrigin bind_op
                            [ synKnownType using_res_ty
                            , synKnownType (n_app tup_ty `mkVisFunTyMany` new_res_ty) ]
                            res_ty $ \ _ _ -> return ()
       ; fmap_op' <- case form of
                       ThenForm -> return noExpr
                       _ -> fmap unLoc . tcCheckPolyExpr (noLocA fmap_op) $
                            mkInfForAllTy alphaTyVar $
                            mkInfForAllTy betaTyVar $
                            (alphaTy `mkVisFunTyMany` betaTy)
                            `mkVisFunTyMany` (n_app alphaTy)
                            `mkVisFunTyMany` (n_app betaTy)
       ; using' <- tcCheckPolyExpr using using_poly_ty
       ; let final_using = fmap (mkHsWrap (WpTyApp tup_ty)) using'
       ; let mk_n_bndr :: Name -> TcId -> TcId
             mk_n_bndr n_bndr_name bndr_id = mkLocalId n_bndr_name Many (n_app (idType bndr_id))
             n_bndr_ids = zipWithEqual "tcMcStmt" mk_n_bndr n_bndr_names bndr_ids
             bindersMap' = bndr_ids `zip` n_bndr_ids
       ; thing <- tcExtendIdEnv n_bndr_ids $
                  thing_inside (mkCheckExpType new_res_ty)
       ; return (TransStmt { trS_stmts = stmts', trS_bndrs = bindersMap'
                           , trS_by = by', trS_using = final_using
                           , trS_ret = return_op', trS_bind = bind_op'
                           , trS_ext = n_app tup_ty
                           , trS_fmap = fmap_op', trS_form = form }, thing) }

-- Parallel comprehension. For a fresh m, the zip operator is checked at
--     forall a b. m a -> m b -> m (a, b)
--
-- Each block gets a tuple type with one fresh type variable per binder.
-- The blocks' tuples are nested as right-associated pairs ('mk_tuple_ty').
-- 'mk_tuple_ty' uses 'foldr1', so at least one block is assumed.
--
-- The bind operator is checked at
--     m tuple_ty -> (tuple_ty -> inner_res_ty) -> res_ty
-- Inside it, 'loop' checks each block at (m tup_ty) for that block's tuple
-- type, and checks the block's return operator on the tuple of its
-- binders. Later blocks, and finally 'thing_inside', are checked in the
-- continuation of earlier ones.
tcMcStmt ctxt (ParStmt _ bndr_stmts_s mzip_op bind_op) res_ty thing_inside
  = do { m_ty   <- newFlexiTyVarTy typeToTypeKind
       ; let mzip_ty  = mkInfForAllTys [alphaTyVar, betaTyVar] $
                        (m_ty `mkAppTy` alphaTy)
                        `mkVisFunTyMany`
                        (m_ty `mkAppTy` betaTy)
                        `mkVisFunTyMany`
                        (m_ty `mkAppTy` mkBoxedTupleTy [alphaTy, betaTy])
       ; mzip_op' <- unLoc `fmap` tcCheckPolyExpr (noLocA mzip_op) mzip_ty
       ; id_tys_s <- (mapM . mapM) (const (newFlexiTyVarTy liftedTypeKind))
                        [ names | ParStmtBlock _ _ names _ <- bndr_stmts_s ]
       ; let tup_tys  = [ mkBigCoreTupTy id_tys | id_tys <- id_tys_s ]
             tuple_ty = mk_tuple_ty tup_tys
       ; (((blocks', thing), inner_res_ty), bind_op')
            <- tcSyntaxOp MCompOrigin bind_op
                          [ synKnownType (m_ty `mkAppTy` tuple_ty)
                          , SynFun (synKnownType tuple_ty) SynRho ] res_ty $
               \ [inner_res_ty] _ ->
               do { stuff <- loop m_ty (mkCheckExpType inner_res_ty)
                                  tup_tys bndr_stmts_s
                  ; return (stuff, inner_res_ty) }
       ; return (ParStmt inner_res_ty blocks' mzip_op' bind_op', thing) }
  where
    mk_tuple_ty tys = foldr1 (\tn tm -> mkBoxedTupleTy [tn, tm]) tys

    -- Walks the tuple types and the blocks in lockstep; panics if their
    -- lengths differ.
    loop _ inner_res_ty [] [] = do { thing <- thing_inside inner_res_ty
                                   ; return ([], thing) }
    loop m_ty inner_res_ty (tup_ty_in : tup_tys_in)
                           (ParStmtBlock x stmts names return_op : pairs)
      = do { let m_tup_ty = m_ty `mkAppTy` tup_ty_in
           ; (stmts', (ids, return_op', pairs', thing))
                <- tcStmtsAndThen ctxt tcMcStmt stmts (mkCheckExpType m_tup_ty) $
                   \m_tup_ty' ->
                   do { ids <- tcLookupLocalIds names
                      ; let tup_ty = mkBigCoreVarTupTy ids
                      ; (_, return_op') <-
                          tcSyntaxOp MCompOrigin return_op
                                     [synKnownType tup_ty] m_tup_ty' $
                                     \ _ _ -> return ()
                      ; (pairs', thing) <- loop m_ty inner_res_ty tup_tys_in pairs
                      ; return (ids, return_op', pairs', thing) }
           ; return (ParStmtBlock x stmts' ids return_op' : pairs', thing) }
    loop _ _ _ _ = panic "tcMcStmt.loop"

tcMcStmt _ stmt _ _
  = pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt)
-- | Statement checker for @do@ and @mdo@ blocks.
--
-- The result type threaded through the statements is the type of the rest
-- of the block. The bind, then, join, return, mfix and fail operators are
-- taken from the statements and checked with 'tcSyntaxOp'.
tcDoStmt :: TcExprStmtChecker

-- The final statement: checked directly against the block's result type.
-- Nothing follows it, so the continuation's argument is a panic that must
-- never be forced.
tcDoStmt _ (LastStmt x body noret _) res_ty thing_inside
  = do { body' <- tcMonoExprNC body res_ty
       ; thing <- thing_inside (panic "tcDoStmt: thing_inside")
       ; return (LastStmt x body' noret noSyntaxExpr, thing) }

-- pat <- rhs. The bind operator is checked at
--     rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
-- Then:
-- * rhs is checked at rhs_ty and the pattern at pat_ty, each scaled by the
--   multiplicities 'tcSyntaxOp' reports;
-- * the rest of the block is checked at new_res_ty, inside the pattern's
--   scope;
-- * rhs_ty must have a fixed runtime representation;
-- * a fail operator, if present, is checked only when the pattern is
--   refutable (see 'tcMonadFailOp').
tcDoStmt ctxt (BindStmt xbsrn pat rhs) res_ty thing_inside
  = do {
         ((rhs_ty, rhs', pat_mult, pat', new_res_ty, thing), bind_op')
            <- tcSyntaxOp DoOrigin (xbsrn_bindOp xbsrn) [SynRho, SynFun SynAny SynRho] res_ty $
               \ [rhs_ty, pat_ty, new_res_ty] [rhs_mult,fun_mult,pat_mult] ->
               do { rhs' <-tcScalingUsage rhs_mult $ tcCheckMonoExprNC rhs rhs_ty
                  ; (pat', thing) <- tcScalingUsage fun_mult $ tcCheckPat (StmtCtxt ctxt) pat (Scaled pat_mult pat_ty) $
                                     thing_inside (mkCheckExpType new_res_ty)
                  ; return (rhs_ty, rhs', pat_mult, pat', new_res_ty, thing) }
       ; hasFixedRuntimeRep_syntactic (FRRBindStmt DoNotation) rhs_ty
       ; fail_op' <- fmap join . forM (xbsrn_failOp xbsrn) $ \fail ->
           tcMonadFailOp (DoPatOrigin pat) pat' fail new_res_ty
       ; let xbstc = XBindStmtTc
               { xbstc_bindOp = bind_op'
               , xbstc_boundResultType = new_res_ty
               , xbstc_boundResultMult = pat_mult
               , xbstc_failOp = fail_op'
               }
       ; return (BindStmt xbstc pat' rhs', thing) }

-- An applicative-do group, checked by 'tcApplicativeStmts'.
-- * Without a join operator, the group is checked directly at res_ty.
-- * With one, the join operator is checked with a single argument of type
--   rhs_ty and result res_ty. The group is then checked at rhs_ty, scaled
--   by the multiplicity reported for that argument.
-- The continuation receives the body type as a checking-mode expected type.
tcDoStmt ctxt (ApplicativeStmt _ pairs mb_join) res_ty thing_inside
  = do { let tc_app_stmts ty = tcApplicativeStmts ctxt pairs ty $
                thing_inside . mkCheckExpType
       ; ((pairs', body_ty, thing), mb_join') <- case mb_join of
            Nothing -> (, Nothing) <$> tc_app_stmts res_ty
            Just join_op ->
              second Just <$>
              (tcSyntaxOp DoOrigin join_op [SynRho] res_ty $
               \ [rhs_ty] [rhs_mult] -> tcScalingUsage rhs_mult $ tc_app_stmts (mkCheckExpType rhs_ty))
       ; return (ApplicativeStmt body_ty pairs' mb_join', thing) }

-- rhs; rest. The then operator is checked at
--     rhs_ty -> new_res_ty -> res_ty
-- rhs is checked at rhs_ty and the rest at new_res_ty, each scaled by the
-- reported multiplicity. Both types must have a fixed runtime
-- representation. No guard operator is involved ('noSyntaxExpr').
tcDoStmt _ (BodyStmt _ rhs then_op _) res_ty thing_inside
  = do  {
        ; ((rhs', rhs_ty, new_res_ty, thing), then_op')
            <- tcSyntaxOp DoOrigin then_op [SynRho, SynRho] res_ty $
               \ [rhs_ty, new_res_ty] [rhs_mult,fun_mult] ->
               do { rhs' <- tcScalingUsage rhs_mult $ tcCheckMonoExprNC rhs rhs_ty
                  ; thing <- tcScalingUsage fun_mult $ thing_inside (mkCheckExpType new_res_ty)
                  ; return (rhs', rhs_ty, new_res_ty, thing) }
        ; hasFixedRuntimeRep_syntactic (FRRBodyStmt DoNotation 1) rhs_ty
        ; hasFixedRuntimeRep_syntactic (FRRBodyStmt DoNotation 2) new_res_ty
        ; return (BodyStmt rhs_ty rhs' then_op' noSyntaxExpr, thing) }

-- A rec block.
--
-- The tuple names are the rec names, followed by those later names not
-- already among them. Each gets a fresh type, and all are brought into
-- scope at multiplicity 'Many' while the block is checked.
--
-- The statements' type (stmts_ty) is inferred. At their end:
-- * an occurrence of each tuple name is checked against its fresh type
--   (tup_rets);
-- * the return operator is checked on the tuple.
--
-- The mfix operator is checked as applied to a function (tup_ty -> stmts_ty),
-- with its result type inferred. The bind operator is checked at
--     mfix_res_ty -> (tup_ty -> new_res_ty) -> res_ty
-- and the rest of the enclosing block at new_res_ty. recS_later_rets is
-- left empty.
tcDoStmt ctxt (RecStmt { recS_stmts = L l stmts, recS_later_ids = later_names
                       , recS_rec_ids = rec_names, recS_ret_fn = ret_op
                       , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op })
         res_ty thing_inside
  = do  { let tup_names = rec_names ++ filterOut (`elem` rec_names) later_names
        ; tup_elt_tys <- newFlexiTyVarTys (length tup_names) liftedTypeKind
        ; let tup_ids = zipWith (\n t -> mkLocalId n Many t) tup_names tup_elt_tys
              tup_ty  = mkBigCoreTupTy tup_elt_tys
        ; tcExtendIdEnv tup_ids $ do
        { ((stmts', (ret_op', tup_rets)), stmts_ty)
                <- tcInfer $ \ exp_ty ->
                   tcStmtsAndThen ctxt tcDoStmt stmts exp_ty $ \ inner_res_ty ->
                   do { tup_rets <- zipWithM tcCheckId tup_names
                                      (map mkCheckExpType tup_elt_tys)
                      ; (_, ret_op')
                          <- tcSyntaxOp DoOrigin ret_op [synKnownType tup_ty]
                                        inner_res_ty $ \_ _ -> return ()
                      ; return (ret_op', tup_rets) }
        ; ((_, mfix_op'), mfix_res_ty)
            <- tcInfer $ \ exp_ty ->
               tcSyntaxOp DoOrigin mfix_op
                          [synKnownType (mkVisFunTyMany tup_ty stmts_ty)] exp_ty $
               \ _ _ -> return ()
        ; ((thing, new_res_ty), bind_op')
            <- tcSyntaxOp DoOrigin bind_op
                          [ synKnownType mfix_res_ty
                          , SynFun (synKnownType tup_ty) SynRho ]
                          res_ty $
               \ [new_res_ty] _ ->
               do { thing <- thing_inside (mkCheckExpType new_res_ty)
                  ; return (thing, new_res_ty) }
        ; let rec_ids = takeList rec_names tup_ids
        ; later_ids <- tcLookupLocalIds later_names
        ; traceTc "tcdo" $ vcat [ppr rec_ids <+> ppr (map idType rec_ids),
                                 ppr later_ids <+> ppr (map idType later_ids)]
        ; return (RecStmt { recS_stmts = L l stmts', recS_later_ids = later_ids
                          , recS_rec_ids = rec_ids, recS_ret_fn = ret_op'
                          , recS_mfix_fn = mfix_op', recS_bind_fn = bind_op'
                          , recS_ext = RecStmtTc
                            { recS_bind_ty = new_res_ty
                            , recS_later_rets = []
                            , recS_rec_rets = tup_rets
                            , recS_ret_ty = stmts_ty} }, thing)
        }}

tcDoStmt _ stmt _ _
  = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)
-- | Typecheck the fail operator for a bind whose pattern is @pat@.
--
-- Returns 'Nothing' when 'isIrrefutableHsPat' (under the current flags)
-- says the pattern cannot fail. Otherwise @fail_op@ is checked as taking a
-- 'String' argument and returning @res_ty@, and the checked operator is
-- returned.
tcMonadFailOp :: CtOrigin
              -> LPat GhcTc
              -> SyntaxExpr GhcRn
              -> TcType
              -> TcRn (FailOperator GhcTc)
tcMonadFailOp orig pat fail_op res_ty
  = do { dflags <- getDynFlags
       ; if isIrrefutableHsPat dflags pat
           then return Nothing
           else fmap (Just . snd) check_fail_op }
  where
    check_fail_op = tcSyntaxOp orig fail_op [synKnownType stringTy]
                               (mkCheckExpType res_ty) (\_ _ -> return ())
-- | Typecheck the arguments of an applicative-do group, then the
-- continuation.
--
-- For @n@ arguments, fresh types are created for:
-- * the body, @body_ty@;
-- * each argument @i@: a pattern type @pat_ty_i@ and an expression type
--   @exp_ty_i@.
--
-- The operators are checked first, as a chain:
-- * the first operator at @fun_ty -> exp_ty_1 -> t_1@, where
--   @fun_ty = pat_ty_1 -> ... -> pat_ty_n -> body_ty@;
-- * each later operator @i@ at @t_(i-1) -> exp_ty_i -> t_i@.
-- The intermediate types @t_1 .. t_(n-1)@ are inferred, one
-- 'newInferExpType' per operator except the last. The last operator's
-- result type is the given @rhs_ty@.
--
-- Each argument is then checked on its own. Finally @thing_inside@ runs at
-- @body_ty@ with the pattern binders of every argument in scope.
--
-- Returns the checked operator/argument pairs, @body_ty@, and the
-- continuation's result.
tcApplicativeStmts
  :: HsStmtContext GhcTc
  -> [(SyntaxExpr GhcRn, ApplicativeArg GhcRn)]
  -> ExpRhoType
  -> (TcRhoType -> TcM t)
  -> TcM ([(SyntaxExpr GhcTc, ApplicativeArg GhcTc)], Type, t)
tcApplicativeStmts ctxt pairs rhs_ty thing_inside
  = do { body_ty <- newFlexiTyVarTy liftedTypeKind
       ; let arity = length pairs
         -- One inferred result type per operator except the last, whose
         -- result type is rhs_ty. (Previously referenced an unbound 'arity1'.)
       ; ts <- replicateM (arity - 1) $ newInferExpType
       ; exp_tys <- replicateM arity $ newFlexiTyVarTy liftedTypeKind
       ; pat_tys <- replicateM arity $ newFlexiTyVarTy liftedTypeKind
       ; let fun_ty = mkVisFunTysMany pat_tys body_ty
       ; let (ops, args) = unzip pairs
       -- Operators are checked before the arguments.
       ; ops' <- goOps fun_ty (zip3 ops (ts ++ [rhs_ty]) exp_tys)
       ; args' <- mapM (goArg body_ty) (zip3 args pat_tys exp_tys)
       ; res <- tcExtendIdEnv (concatMap get_arg_bndrs args') $
                thing_inside body_ty
       ; return (zip ops' args', body_ty, res) }
  where
    -- Check each operator at  t_left -> exp_ty -> t_i, threading the
    -- (read-back) result type into the next operator's first argument.
    goOps _ [] = return []
    goOps t_left ((op,t_i,exp_ty) : ops)
      = do { (_, op')
               <- tcSyntaxOp DoOrigin op
                             [synKnownType t_left, synKnownType exp_ty] t_i $
                  \ _ _ -> return ()
           ; t_i' <- readExpType t_i
           ; ops' <- goOps t_i' ops
           ; return (op' : ops') }

    -- Check one argument against its pattern and expression types.
    goArg :: Type -> (ApplicativeArg GhcRn, Type, Type)
          -> TcM (ApplicativeArg GhcTc)
    -- A single bind: the source span covers pattern and rhs, and the error
    -- context shows it as a bind statement. rhs is checked at exp_ty and the
    -- pattern (unrestricted) at pat_ty. A fail operator is checked against
    -- body_ty only for a refutable pattern.
    goArg body_ty (ApplicativeArgOne
                   { xarg_app_arg_one = fail_op
                   , app_arg_pattern = pat
                   , arg_expr = rhs
                   , ..
                   }, pat_ty, exp_ty)
      = setSrcSpan (combineSrcSpans (getLocA pat) (getLocA rhs)) $
        addErrCtxt (pprStmtInCtxt ctxt (mkRnBindStmt pat rhs)) $
        do { rhs' <- tcCheckMonoExprNC rhs exp_ty
           ; (pat', _) <- tcCheckPat (StmtCtxt ctxt) pat (unrestricted pat_ty) $
                          return ()
           ; fail_op' <- fmap join . forM fail_op $ \fail ->
               tcMonadFailOp (DoPatOrigin pat) pat' fail body_ty
           ; return (ApplicativeArgOne
                     { xarg_app_arg_one = fail_op'
                     , app_arg_pattern = pat'
                     , arg_expr = rhs'
                     , .. }
                    ) }
    -- A nested statement sequence, checked with 'tcDoStmt' at exp_ty. At its
    -- end, 'ret' is checked at the final result type and the pattern
    -- (unrestricted) at pat_ty.
    goArg _body_ty (ApplicativeArgMany x stmts ret pat ctxt, pat_ty, exp_ty)
      = do { (stmts', (ret',pat')) <-
                tcStmtsAndThen (HsDoStmt ctxt) tcDoStmt stmts (mkCheckExpType exp_ty) $
                \res_ty -> do
                  { ret' <- tcExpr ret res_ty
                  ; (pat', _) <- tcCheckPat (StmtCtxt (HsDoStmt ctxt)) pat (unrestricted pat_ty) $
                                 return ()
                  ; return (ret', pat')
                  }
           ; return (ApplicativeArgMany x stmts' ret' pat' ctxt) }

    -- Binders introduced by an argument's pattern (excluding dictionary
    -- binders).
    get_arg_bndrs :: ApplicativeArg GhcTc -> [Id]
    get_arg_bndrs (ApplicativeArgOne { app_arg_pattern = pat }) = collectPatBinders CollNoDictBinders pat
    get_arg_bndrs (ApplicativeArgMany { bv_pattern = pat }) = collectPatBinders CollNoDictBinders pat
-- | Check that every alternative of a match group has the same number of
-- patterns as the first one.
--
-- On a mismatch it fails, reporting the source locations of the first
-- alternative and of the first offending one. An empty group, or a group
-- whose alternatives all agree, passes.
checkArgCounts :: AnnoBody body
               => HsMatchContext GhcTc -> MatchGroup GhcRn (LocatedA (body GhcRn))
               -> TcM ()
checkArgCounts _ (MG { mg_alts = L _ [] })
  = return ()
checkArgCounts matchContext (MG { mg_alts = L _ (match1:matches) })
  -- Bind the first offending alternative directly rather than using 'head'.
  | (bad_match : _) <- bad_matches
  = failWithTc $ TcRnUnknownMessage $ mkPlainError noHints $
    (vcat [ err_msg <+>
            text "have different numbers of arguments"
          , nest 2 (ppr (getLocA match1))
          , nest 2 (ppr (getLocA bad_match))])
  | otherwise
  = return ()
  where
    n_args1     = args_in_match match1
    -- Alternatives whose pattern count differs from the first's, in order.
    bad_matches = [m | m <- matches, args_in_match m /= n_args1]
    err_msg     = pprMatchContextNouns matchContext

    args_in_match :: (LocatedA (Match GhcRn body1) -> Int)
    args_in_match (L _ (Match { m_pats = pats })) = length pats