在x86-64 / Windows下如何正确切换上下文?

内容来源于 Stack Overflow,并遵循CC BY-SA 3.0许可协议进行翻译与使用

  • 回答 (1)
  • 关注 (0)
  • 查看 (86)

我正在为 x86-64 实现自己的纤程(fiber)库。动机之一是缺少跨平台的标准上下文切换接口:GCC/Linux 上的 makecontext 以可变参数的形式接收 void * 参数,而 Windows 的纤程 API 只接受 1 个 void * 参数。另一方面,我也想借此练习 API 的设计与实现,并弄清其工作原理。在我的 API 中,协程函数接受两个参数:一个协程上下文和一个 void * 参数。

/* Saved machine state for one side of a coroutine switch.
 *
 * NOTE(review): the assembly addresses these fields through OFF_* constants
 * that are defined elsewhere; the field order below must stay in sync with
 * them, so do not reorder or insert fields in the middle.
 * NOTE(review): the Win64 ABI also treats XMM6-XMM15 (and the MXCSR / x87
 * control words) as callee-saved.  This structure has no room for them, so a
 * switch that only fills these slots does not preserve them. */
struct win64_mcontext {
  U64 rdi, rsi, rbx, rbp, r12, r13, r14, r15; /* callee-saved GPRs */
  U64 rax, rsp, rip;  /* rax is not touched by the visible asm; rsp/rip = resume point */
  U64 rcx, rdx, r8, r9; /* argument registers, reloaded by __cotransfer on switch-in */
};

/* One coroutine instance:
 *   caller - state of whoever entered/resumed the coroutine (resumed on yield)
 *   callee - state of the coroutine body itself (resumed on enter/resume)
 *   state  - status word; not interpreted by the code shown so far */
struct coroutine {
  struct win64_mcontext caller, callee;
  U32 state;
};

/* Allocate a coroutine and prepare its callee context so that the first
 * coenter() starts FUNC on the caller-supplied STACK of STACK_SIZE bytes.
 *
 * The object is zero-filled so that `state` and every register slot that
 * _coprepare does not write (e.g. the rcx/rdx/r8/r9 slots that the switch
 * code reloads) start out deterministic instead of holding heap garbage.
 *
 * On allocation failure *co is set to NULL and nothing else happens; callers
 * must check *co before using it. */
void coprepare(struct coroutine **co,
           void *stack, U64 stack_size, cofunc_t func)
{
  struct coroutine *c = calloc(1, sizeof *c); /* TODO: replace with something cheaper */

  *co = c;
  if (!c) {
    return; /* out of memory: report via *co == NULL instead of crashing */
  }
  _coprepare(&c->caller, &c->callee, stack, stack_size, func);
}

/* Start (or re-enter) the coroutine body: the current state is saved into
   co->caller and execution continues in co->callee, passing ENTER_ARG along. */
void coenter(struct coroutine *co, void *enter_arg)
{
  struct win64_mcontext *from = &co->caller;
  struct win64_mcontext *to = &co->callee;

  _coenter(from, to, enter_arg);
}

/* Suspend the coroutine body: its state goes into co->callee and control
   returns to the saved co->caller, passing YIELD_ARG along. */
void coyield(struct coroutine *co, void *yield_arg)
{
  struct win64_mcontext *from = &co->callee;
  struct win64_mcontext *to = &co->caller;

  _coyield(from, to, yield_arg);
}

/* Switch from the caller back into the suspended coroutine body.
   Status reporting is not implemented yet, so this always returns 0. */
int coresume(struct coroutine *co)
{
  struct win64_mcontext *from = &co->caller;
  struct win64_mcontext *to = &co->callee;
  const int status = 0; /* punt this for now */

  _coresume(from, to);
  return status;
}

_coenter、_coyield 和 _coresume 都实现为一条 jmp __cotransfer 指令(即直接跳转到 __cotransfer):

;;; _coprepare(struct win64_mcontext *old, struct win64_mcontext *new,
;;;            void *stack, U64 stack_size,
;;;            cofunc_t func);
;;; Snapshots the caller's callee-saved GPRs into 'old' and builds an initial
;;; context in 'new' that, when switched to, starts executing 'func' at the
;;; top of [stack, stack + stack_size).
;;;
;;; RCX        -> old
;;; RDX        -> new
;;; R8         -> stack
;;; R9         -> stack_size   (assumed >= 48 bytes; not checked)
;;; [RSP + 40] -> func         (5th argument: 8-byte return address + 32-byte home area)
_coprepare proc
    ;; save non-volatile GPRs in 'old'
    mov [RCX + OFF_RSI], RSI
    mov [RCX + OFF_RDI], RDI
    mov [RCX + OFF_RBP], RBP
    mov [RCX + OFF_RBX], RBX
    mov [RCX + OFF_R12], R12
    mov [RCX + OFF_R13], R13
    mov [RCX + OFF_R14], R14
    mov [RCX + OFF_R15], R15

    ;; save stack frame info in 'old'
    mov R10, RSP
    mov R11, OFFSET _coyield

    mov [RCX + OFF_RSP], R10
    mov [RCX + OFF_RIP], R11

    ;; Load the 5th argument.  (Previously 'lea R11, [RBP - 32]', which took an
    ;; address derived from the caller's RBP instead of the argument value.)
    mov R11, [RSP + 40]

    ;; Build the entry stack: align the top down to 16, then reserve a 32-byte
    ;; home area plus one return-address slot, so that RSP == 8 (mod 16) at
    ;; func's first instruction exactly as after a normal 'call'.
    lea R10, [R8 + R9]
    and R10, -16
    sub R10, 40
    mov QWORD PTR [R10], 0   ; fake return address: returning from func faults at 0

    ;; init GPRs in 'new'
    xor EAX, EAX
    mov [RDX + OFF_RSI], RAX
    mov [RDX + OFF_RDI], RAX
    mov [RDX + OFF_RBX], RAX
    mov [RDX + OFF_RBP], R10
    mov [RDX + OFF_R12], RAX
    mov [RDX + OFF_R13], RAX
    mov [RDX + OFF_R14], RAX
    mov [RDX + OFF_R15], RAX

    ;; __cotransfer reloads the argument registers from 'new' on switch-in;
    ;; give them defined values instead of leaving them uninitialized.
    mov [RDX + OFF_RCX], RAX
    mov [RDX + OFF_RDX], RAX
    mov [RDX + OFF_R8], RAX
    mov [RDX + OFF_R9], RAX

    mov [RDX + OFF_RSP], R10
    mov [RDX + OFF_RIP], R11

    ret
_coprepare endp

;;; __cotransfer(struct win64_mcontext *old, struct win64_mcontext *new, void *trans_arg);
;;; Saves the current execution state into 'old' so that switching back to it
;;; later looks like an ordinary return from the _coenter/_coyield/_coresume
;;; call, then resumes 'new'.
;;; RCX : old
;;; RDX : new
;;; R8  : trans_arg   (handed to the resumed side in RAX)
;;;
;;; Reached via 'jmp' from _coenter/_coyield/_coresume, so [RSP] holds the
;;; return address pushed by the C caller's 'call'.
;;; NOTE(review): XMM6-XMM15 and MXCSR are callee-saved under the Win64 ABI,
;;; but win64_mcontext has no slots for them, so they are NOT preserved here.
;;; NOTE(review): NT_TIB StackBase/StackLimit are not switched either; Windows
;;; stack probing and exception dispatch consult them.
__cotransfer proc
    ;; save non-volatile GPRs
    mov [RCX + OFF_RSI], RSI
    mov [RCX + OFF_RDI], RDI
    mov [RCX + OFF_RBX], RBX
    mov [RCX + OFF_RBP], RBP
    mov [RCX + OFF_R12], R12
    mov [RCX + OFF_R13], R13
    mov [RCX + OFF_R14], R14
    mov [RCX + OFF_R15], R15

    ;; save argument GPRs
    mov [RCX + OFF_RCX], RCX
    mov [RCX + OFF_RDX], RDX
    mov [RCX + OFF_R8], R8
    mov [RCX + OFF_R9], R9

    ;; Resume point = the return address on top of the stack, with RSP as it
    ;; will be once that address has been popped.  (The old code stored the
    ;; *address* of the stack slot as RIP -- a jump into the stack -- and kept
    ;; RSP pointing at the return address, i.e. off by 8.)
    mov R10, [RSP]
    lea R11, [RSP + 8]
    mov [RCX + OFF_RIP], R10
    mov [RCX + OFF_RSP], R11

    ;; load non-volatile GPRs
    mov RSI, [RDX + OFF_RSI]
    mov RDI, [RDX + OFF_RDI]
    mov RBX, [RDX + OFF_RBX]
    mov RBP, [RDX + OFF_RBP]
    mov R12, [RDX + OFF_R12]
    mov R13, [RDX + OFF_R13]
    mov R14, [RDX + OFF_R14]
    mov R15, [RDX + OFF_R15]

    ;; trans_arg becomes the value the resumed side sees in RAX
    mov RAX, R8

    ;; switch stacks and load the argument registers of 'new'
    mov R11, RDX
    mov RSP, [R11 + OFF_RSP]
    mov RCX, [R11 + OFF_RCX]
    mov RDX, [R11 + OFF_RDX]
    mov R8,  [R11 + OFF_R8]
    mov R9,  [R11 + OFF_R9]

    ;; continue at the saved resume point; the new stack is left untouched
    jmp QWORD PTR [R11 + OFF_RIP]
__cotransfer endp

程序总是在 __cotransfer 中的某处崩溃,而调试时我也无法确定最终究竟发生了什么。

提问于
用户回答于

Windows 下的一个最小实现(当然,系统本身已经提供了现成的纤程实现)大致如下:

C/C++ 部分:

// Mirror of the ntdll INITIAL_TEB structure that RtlCreateUserStack fills in.
// The layout has to match the system's definition, so keep the field order.
// MyCreateFiber reads only StackBase, StackLimit and StackAllocationBase.
typedef struct _INITIAL_TEB
{
    PVOID OldStackBase;
    PVOID OldStackLimit;
    PVOID StackBase;            // upper end of the new stack; frames are built downward from here
    PVOID StackLimit;           // presumably the current lower bound of the usable stack -- confirm
    PVOID StackAllocationBase;  // reservation base; this is what RtlFreeUserStack is given
} INITIAL_TEB, *PINITIAL_TEB;

// Manual prototypes for the ntdll routines used to allocate/free a stack for
// a new fiber.  NOTE(review): these are undocumented ntdll exports; the
// signatures come from this answer and should be confirmed for the target
// Windows version.

// Releases a stack previously obtained from RtlCreateUserStack
// (MyCreateFiber/MyDeleteFiber pass INITIAL_TEB::StackAllocationBase).
extern "C"
NTSYSAPI 
NTSTATUS 
NTAPI RtlFreeUserStack  (   _In_ PVOID      AllocationBase  );

// Allocates a stack and reports its bounds in *InitialTeb.  MyCreateFiber
// calls it with (0, dwStackSize, 0, 0x1000, 0x10000, ...); the result is
// treated as success when it is >= 0 (NTSTATUS convention).
extern "C"
NTSYSAPI 
NTSTATUS 
NTAPI   
RtlCreateUserStack (
                    _In_opt_ SIZE_T CommittedStackSize, 
                    _In_opt_ SIZE_T MaximumStackSize, 
                    _In_opt_ ULONG_PTR ZeroBits, 
                    _In_ SIZE_T PageSize, 
                    _In_ ULONG_PTR ReserveAlignment, 
                    _Out_ PINITIAL_TEB InitialTeb);

// Per-fiber state.  The first two members must stay in this order: the
// assembly FIBER_CONTEXT STRUCT declares Tib followed by StackPointer, and
// SwitchToContext addresses them through those offsets.
struct FIBER_CONTEXT
{
    NT_TIB Tib;                 // snapshot of the TEB's NT_TIB while this fiber is switched out
    PVOID StackPointer;         // saved RSP while this fiber is switched out
    PVOID StackAllocationBase;  // stack reservation freed by MyDeleteFiber (not used by the asm)
};

// Routines implemented in the x64 assembly part.
extern "C"
{
    // Entry trampoline stored as the first return address on a new fiber's
    // stack (see MyCreateFiber); it is reached via 'ret', never called directly.
    void __cdecl FiberStart();
    // Saves the running fiber's state and resumes ctx (ctx arrives in RCX).
    // NOTE(review): MSVC accepts and ignores __fastcall/__cdecl on x64.
    void __fastcall SwitchToContext(FIBER_CONTEXT* ctx);
}

// Turns the calling thread into a fiber: allocates a FIBER_CONTEXT and
// publishes it through NT_TIB::FiberData, which SwitchToContext reads to find
// the fiber being switched away from.  Returns 0 if no context was obtained.
// NOTE(review): a plain `new` throws on failure in standard C++, so the null
// check only matters with a non-throwing allocator (e.g. new(std::nothrow)).
FIBER_CONTEXT* MyConvertThreadToFiber()
{
    FIBER_CONTEXT* self = new FIBER_CONTEXT;
    if (!self)
    {
        return 0;
    }

    NT_TIB* tib = (NT_TIB*)NtCurrentTeb();
    tib->FiberData = self;
    return self;
}

void MyConvertFiberToThread()
{
    if (FIBER_CONTEXT* ctx = (FIBER_CONTEXT*)((NT_TIB*)NtCurrentTeb())->FiberData)
    {
        delete ctx;
        ((NT_TIB*)NtCurrentTeb())->FiberData = 0;
    }
}

// Creates a fiber that will run lpStartAddress(lpParameter) on a new stack
// obtained from RtlCreateUserStack (dwStackSize is passed as its maximum size).
// Returns the new context, or 0 if the stack or the context could not be
// allocated; in the latter case the stack is released again.
//
// The new stack is pre-seeded so that the first SwitchToContext to this fiber
// behaves like a resume.  Pointer-sized slots, indexed down from StackBase:
//   [-1], [-2]  untouched; with [-3], [-4] they form the 32-byte home area of
//               the call that FiberStart makes
//   [-3]        lpStartAddress  (FiberStart: call [rsp + 8])
//   [-4]        lpParameter     (FiberStart: mov rcx, [rsp])
//   [-5]        FiberStart      (consumed by SwitchToContext's final ret)
//   [-13..-6]   8 slots popped into rbp, rbx, rdi, rsi, r12..r15
FIBER_CONTEXT* WINAPI MyCreateFiber(
                          __in      SIZE_T dwStackSize,
                          __in      PFIBER_START_ROUTINE lpStartAddress,
                          __in_opt  PVOID lpParameter
                          )
{
    enum { kHomeSlots = 4, kReturnSlots = 1, kSavedGprSlots = 8 };

    INITIAL_TEB stack;
    if (RtlCreateUserStack(0, dwStackSize, 0, 0x1000, 0x10000, &stack) < 0)
    {
        return 0;
    }

    FIBER_CONTEXT* fiber = new FIBER_CONTEXT;
    if (!fiber)
    {
        RtlFreeUserStack(stack.StackAllocationBase);
        return 0;
    }

    // NT_TIB to install while this fiber runs: its own stack bounds,
    // ExceptionList/ArbitraryUserPointer cleared, FiberData pointing back at
    // the fiber, and the remaining fields taken from the creating thread.
    const NT_TIB* current = (NT_TIB*)NtCurrentTeb();
    fiber->Tib.ExceptionList = 0;
    fiber->Tib.StackBase = stack.StackBase;
    fiber->Tib.StackLimit = stack.StackLimit;
    fiber->Tib.SubSystemTib = current->SubSystemTib;
    fiber->Tib.FiberData = fiber;
    fiber->Tib.ArbitraryUserPointer = 0;
    fiber->Tib.Self = current->Self;
    fiber->StackAllocationBase = stack.StackAllocationBase;

    void** top = (void**)stack.StackBase;
    top[-3] = lpStartAddress;
    top[-4] = lpParameter;
    top[-5] = FiberStart;
    fiber->StackPointer = top - (kHomeSlots + kReturnSlots + kSavedGprSlots);
    return fiber;
}

// Releases a fiber created by MyCreateFiber: first its stack reservation,
// then the context object itself.
// NOTE(review): presumably must not be called for the fiber that is currently
// running, since its stack would be freed underneath it.
VOID WINAPI MyDeleteFiber(FIBER_CONTEXT* ctx)
{
    PVOID stackReservation = ctx->StackAllocationBase;
    RtlFreeUserStack(stackReservation);
    delete ctx;
}

汇编(x64)实现部分:

; Assembly mirror of NT_TIB (seven pointer-sized fields on x64).  The order
; must match the C/C++ NT_TIB: SwitchToContext copies it as
; SIZEOF NT_TIB / SIZEOF QWORD qwords and reads NT_TIB.Self / NT_TIB.FiberData.
NT_TIB STRUCT
    ExceptionList DQ ?
    StackBase DQ ?
    StackLimit DQ ?
    SubSystemTib DQ ?
    FiberData DQ ?              ; current FIBER_CONTEXT* (set by MyConvertThreadToFiber / switches)
    ArbitraryUserPointer DQ ?
    Self DQ ?                   ; linear address of this NT_TIB, read via gs:
NT_TIB ENDS

; Assembly view of the leading part of the C++ FIBER_CONTEXT.  Only Tib and
; StackPointer are needed here; the C++ struct adds StackAllocationBase after
; them, so these two must remain its first members.
FIBER_CONTEXT STRUCT
    Tib NT_TIB <?>              ; NT_TIB snapshot while the fiber is switched out
    StackPointer DQ ?           ; saved RSP while the fiber is switched out
FIBER_CONTEXT ENDS

extern __imp_ExitThread:QWORD

_TEXT segment 'CODE'

; First code executed on a new fiber.  It is reached through the 'ret' at the
; end of SwitchToContext, because MyCreateFiber stores its address as the
; initial return address.  On entry [rsp] = lpParameter and
; [rsp + 8] = lpStartAddress.  After the start routine returns, ExitThread is
; called, so returning from a fiber routine ends the whole thread.
FiberStart proc
    mov     rcx,[rsp]               ; first argument = lpParameter
    call    qword ptr [rsp + 8]     ; lpStartAddress(lpParameter); home area = the 4 slots at rsp
    mov     ecx,eax                 ; exit code = whatever the routine left in eax (void routines: unspecified)
    call    [__imp_ExitThread]      ; does not return
FiberStart endp

; SwitchToContext(FIBER_CONTEXT* ctx in RCX)
; Switches from the fiber recorded in NT_TIB.FiberData to 'ctx':
;   1. pushes the callee-saved GPRs onto the current stack,
;   2. stores RSP in the current FIBER_CONTEXT and loads ctx->StackPointer,
;   3. copies the live NT_TIB (via gs:) into the current FIBER_CONTEXT.Tib,
;   4. copies ctx->Tib into the live NT_TIB (its FiberData points at ctx),
;   5. pops the callee-saved GPRs from the new stack and returns on it.
; A fresh fiber's stack is pre-built by MyCreateFiber so that the pops consume
; 8 dummy slots and the 'ret' lands in FiberStart.
; NOTE(review): XMM6-XMM15 are callee-saved under the Win64 ABI but are not
; saved/restored here; fixing that also requires changing the initial frame
; that MyCreateFiber builds.
; NOTE(review): only NT_TIB is switched; other per-stack TEB fields (e.g. the
; deallocation stack used for guard-page stack growth) are left as-is -- confirm.
SwitchToContext proc
    push    r15
    push    r14
    push    r13
    push    r12
    push    rsi
    push    rdi
    push    rbx
    push    rbp

    mov     rax,gs:[NT_TIB.Self]    ; rax -> live NT_TIB of this thread

    mov     rdx,[rax + NT_TIB.FiberData]    ; rdx -> FIBER_CONTEXT being switched out

    mov     [rdx + FIBER_CONTEXT.StackPointer],rsp  ; save current rsp
    mov     rsp,[rcx + FIBER_CONTEXT.StackPointer]  ; set new rsp

    ; save NT_TIB: live NT_TIB -> outgoing FIBER_CONTEXT.Tib
    lea     rdi,[rdx + FIBER_CONTEXT.Tib]
    mov     rsi,rax
    mov     rdx,rcx                 ; keep ctx; rcx becomes the movsq count
    mov     rcx, SIZEOF NT_TIB / SIZEOF QWORD
    rep     movsq

    ; set NT_TIB: ctx->Tib -> live NT_TIB
    mov     rdi,rax
    lea     rsi,[rdx + FIBER_CONTEXT.Tib]
    mov     rcx, SIZEOF NT_TIB / SIZEOF QWORD
    rep     movsq

    pop     rbp                     ; restore from the incoming fiber's stack
    pop     rbx
    pop     rdi
    pop     rsi
    pop     r12
    pop     r13
    pop     r14
    pop     r15
    ret                             ; resume the incoming fiber (FiberStart on first switch)
SwitchToContext endp

_TEXT ENDS

END

使用示例:

// Shared state for the usage example: the two fibers, plus the message
// that the worker prints each time it is resumed.
struct FCTX
{
    FIBER_CONTEXT* MainFiber;
    FIBER_CONTEXT* WorkFiber;
    PCSTR sz;
};

// Worker fiber body.  Each time it is resumed it emits ctx->sz through
// DbgPrint and then hands control back to the main fiber.  It never returns.
void WINAPI FiberProc(FCTX* ctx)
{
    while (true)
    {
        PCSTR message = ctx->sz;
        DbgPrint("%s\n", message);
        SwitchToContext(ctx->MainFiber);
    }
}

// Usage example.  It converts the thread into a fiber and creates a worker
// fiber, then resumes the worker twice with different messages before tearing
// everything down.
void test()
{
    FCTX ctx;

    ctx.MainFiber = MyConvertThreadToFiber();
    if (!ctx.MainFiber)
    {
        return;
    }

    ctx.WorkFiber = MyCreateFiber(0, (PFIBER_START_ROUTINE)FiberProc, &ctx);
    if (ctx.WorkFiber)
    {
        ctx.sz = "task #1";
        SwitchToContext(ctx.WorkFiber);   // runs FiberProc until it switches back

        ctx.sz = "task #2";
        SwitchToContext(ctx.WorkFiber);   // resumes FiberProc's loop

        MyDeleteFiber(ctx.WorkFiber);     // worker is suspended here, not running
    }

    MyConvertFiberToThread();
}

扫码关注云+社区

领取腾讯云代金券