🎉 Support oidc by telegram

This commit is contained in:
omg-xtao 2024-01-14 16:04:25 +08:00 committed by GitHub
parent 43d0f26dda
commit 4fbe5eb0dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
56 changed files with 2543 additions and 1 deletions

12
.env.example Normal file
View File

@ -0,0 +1,12 @@
CONN_URI=sqlite+aiosqlite:///data/db.sqlite3
DEBUG=True
PROJECT_URL=http://127.0.0.1
PROJECT_LOGIN_SUCCESS_URL=http://google.com
PROJECT_PORT=80
JWT_PRIVATE_KEY='data/private_key'
JWT_PUBLIC_KEY='data/public_key'
BOT_TOKEN=xxx
BOT_USERNAME=xxxxBot
BOT_API_ID=111
BOT_API_HASH=aaa
BOT_MANAGER_IDS=[111,222]

4
.gitignore vendored
View File

@ -157,4 +157,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ .idea/
data/

32
README.md Normal file
View File

@ -0,0 +1,32 @@
# Telegram OAuth
## Configuration
```dotenv
CONN_URI=sqlite+aiosqlite:///data/db.sqlite3 # 数据库 uri
DEBUG=True # 调试模式
PROJECT_URL=http://127.0.0.1 # 项目可访问的地址
PROJECT_LOGIN_SUCCESS_URL=http://google.com # 登录成功后跳转的地址
PROJECT_PORT=80 # 项目运行的端口
JWT_PRIVATE_KEY='data/private_key' # jwt 私钥
JWT_PUBLIC_KEY='data/public_key' # jwt 公钥
BOT_TOKEN=xxx # 机器人 token
BOT_USERNAME=xxxxBot # 机器人用户名
BOT_API_ID=111 # api id
BOT_API_HASH=aaa # api hash
BOT_MANAGER_IDS=[111,222] # 管理员 id
```
## OIDC Endpoints
Auth URL : `/oauth2/authorize`
Token URL : `/oauth2/token`
Cert URL : `/oauth2/keys`
## OIDC Client
```sql
INSERT INTO "client" ("grant_types", "response_types", "redirect_uris", "id", "client_id", "client_secret", "scope") VALUES ('authorization_code', 'code', 'https://127.0.0.1/access/callback', 'UUID', '123456', '123456', 'openid profile email');
```

View File

View File

@ -0,0 +1,8 @@
__title__ = "aioauth_fastapi"
__description__ = "aioauth integration for FastAPI."
__url__ = "https://github.com/aliev/aioauth-fastapi"
__version__ = "0.1.2"
__author__ = "Ali Aliyev"
__author_email__ = "ali@aliev.me"
__license__ = "The MIT License (MIT)"
__copyright__ = "Copyright 2021 Ali Aliyev"

38
aioauth_fastapi/forms.py Normal file
View File

@ -0,0 +1,38 @@
"""
.. code-block:: python
from aioauth_fastapi import forms
FastAPI oauth2 forms.
Used to generate an OpenAPI schema.
----
"""
from dataclasses import dataclass
from typing import Optional
from aioauth.types import GrantType, TokenType
from fastapi.params import Form
@dataclass
class TokenForm:
grant_type: Optional[GrantType] = Form(None) # type: ignore
client_id: Optional[str] = Form(None) # type: ignore
client_secret: Optional[str] = Form(None) # type: ignore
redirect_uri: Optional[str] = Form(None) # type: ignore
scope: Optional[str] = Form(None) # type: ignore
username: Optional[str] = Form(None) # type: ignore
password: Optional[str] = Form(None) # type: ignore
refresh_token: Optional[str] = Form(None) # type: ignore
code: Optional[str] = Form(None) # type: ignore
token: Optional[str] = Form(None) # type: ignore
code_verifier: Optional[str] = Form(None) # type: ignore
@dataclass
class TokenIntrospectForm:
token: Optional[str] = Form(None) # type: ignore
token_type_hint: Optional[TokenType] = Form(None) # type: ignore

113
aioauth_fastapi/router.py Normal file
View File

@ -0,0 +1,113 @@
"""
.. code-block:: python
from aioauth_fastapi import router
FastAPI routing of oauth2.
Usage example
.. code-block:: python
from aioauth_fastapi.router import get_oauth2_router
from aioauth.storage import BaseStorage
from aioauth.config import Settings
from aioauth.server import AuthorizationServer
from fastapi import FastAPI
app = FastAPI()
class SQLAlchemyCRUD(BaseStorage):
'''
SQLAlchemyCRUD methods must be implemented here.
'''
# NOTE: Redefinition of the default aioauth settings
# INSECURE_TRANSPORT must be enabled for local development only!
settings = Settings(
INSECURE_TRANSPORT=True,
)
storage = SQLAlchemyCRUD()
authorization_server = AuthorizationServer(storage)
# Include FastAPI router with oauth2 endpoints.
app.include_router(
get_oauth2_router(authorization_server, settings),
prefix="/oauth2",
tags=["oauth2"],
)
----
"""
from typing import Callable, TypeVar
from aioauth.config import Settings
from aioauth.requests import TRequest
from aioauth.server import AuthorizationServer
from aioauth.storage import TStorage
from fastapi import APIRouter, Request
from .utils import (
RequestArguments,
default_request_factory,
to_fastapi_response,
to_oauth2_request,
)
ARequest = TypeVar("ARequest", bound=TRequest)
def get_oauth2_router(
authorization_server: AuthorizationServer[ARequest, TStorage],
settings: Settings = Settings(),
request_factory: Callable[[RequestArguments], ARequest] = default_request_factory,
) -> APIRouter:
"""Function will create FastAPI router with the following oauth2 endpoints:
* POST /token
* Endpoint creates a token response by :py:meth:`aioauth.server.AuthorizationServer.create_token_response`
* POST `/token/introspect`
* Endpoint creates a token introspection by :py:meth:`aioauth.server.AuthorizationServer.create_token_introspection_response`
* GET `/authorize`
* Endpoint creates an authorization response by :py:meth:`aioauth.server.AuthorizationServer.create_authorization_response`
Returns:
:py:class:`fastapi.APIRouter`.
"""
router = APIRouter()
@router.post("/token")
async def token(request: Request):
oauth2_request = await to_oauth2_request(
request=request, request_factory=request_factory, settings=settings
)
oauth2_response = await authorization_server.create_token_response(
oauth2_request
)
return await to_fastapi_response(oauth2_response)
@router.post("/token/introspect")
async def token_introspect(request: Request):
oauth2_request = await to_oauth2_request(
request=request, request_factory=request_factory, settings=settings
)
oauth2_response = (
await authorization_server.create_token_introspection_response(
oauth2_request
)
)
return await to_fastapi_response(oauth2_response)
@router.get("/authorize")
async def authorize(request: Request):
oauth2_request = await to_oauth2_request(
request=request, request_factory=request_factory, settings=settings
)
oauth2_response = await authorization_server.create_authorization_response(
oauth2_request
)
return await to_fastapi_response(oauth2_response)
return router

98
aioauth_fastapi/utils.py Normal file
View File

@ -0,0 +1,98 @@
"""
.. code-block:: python
from aioauth_fastapi import utils
Core utils for integration with FastAPI
----
"""
import json
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from aioauth.collections import HTTPHeaderDict
from aioauth.config import Settings
from aioauth.requests import Post, Query, TRequest, TUser
from aioauth.requests import Request as OAuth2Request
from aioauth.responses import Response as OAuth2Response
from fastapi import Request, Response
@dataclass
class RequestArguments:
headers: HTTPHeaderDict
method: str
post_args: Dict
query_args: Dict
settings: Settings
url: str
user: Optional[TUser]
def default_request_factory(request_args: RequestArguments) -> OAuth2Request:
return OAuth2Request(
headers=request_args.headers,
method=request_args.method, # type: ignore
post=Post(**request_args.post_args), # type: ignore
query=Query(**request_args.query_args), # type: ignore
settings=request_args.settings,
url=request_args.url,
user=request_args.user,
)
async def to_oauth2_request(
request: Request,
settings: Settings = Settings(),
request_factory: Callable[[RequestArguments], TRequest] = default_request_factory,
) -> TRequest:
"""Converts :py:class:`fastapi.Request` instance to :py:class:`aioauth.requests.Request` instance"""
form = await request.form()
post_args = dict(form)
query_args = dict(request.query_params)
need_args = [
"client_id",
"redirect_uri",
"response_type",
"state",
"scope",
"nonce",
"code_challenge_method",
"code_challenge",
"response_mode",
]
for arg in list(query_args.keys()):
if arg not in need_args:
del query_args[arg]
method = request.method
headers = HTTPHeaderDict(**request.headers)
url = str(request.url)
user = None
if request.user.is_authenticated:
user = request.user
request_args = RequestArguments(
headers=headers,
method=method,
post_args=post_args,
query_args=query_args,
settings=settings,
url=url,
user=user,
)
return request_factory(request_args)
async def to_fastapi_response(oauth2_response: OAuth2Response) -> Response:
"""Converts :py:class:`aioauth.responses.Response` instance to :py:class:`fastapi.Response` instance"""
response_content = oauth2_response.content
headers = dict(oauth2_response.headers)
status_code = oauth2_response.status_code
content = json.dumps(response_content)
return Response(content=content, headers=headers, status_code=status_code)

89
alembic.ini Normal file
View File

@ -0,0 +1,89 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic/
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to ./versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat ./versions
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite+aiosqlite:///data/db.sqlite3
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

90
alembic/env.py Normal file
View File

@ -0,0 +1,90 @@
import asyncio
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import SQLModel
from src.config import settings
from src.oauth2.models import * # noqa
from src.users.models import * # noqa
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
config.set_main_option("sqlalchemy.url", str(settings.CONN_URI))
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = SQLModel.metadata # type: ignore
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection):
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = AsyncEngine(
engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
)
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())

25
alembic/script.py.mako Normal file
View File

@ -0,0 +1,25 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,226 @@
"""Initial migrations
Revision ID: 07a7ace268a7
Revises:
Create Date: 2021-10-02 22:50:10.418498
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "07a7ace268a7"
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"users",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=True),
sa.Column("is_blocked", sa.Boolean(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_users_id"), "users", ["id"], unique=True)
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
op.create_index(op.f("ix_users_is_blocked"), "users", ["is_blocked"], unique=False)
op.create_index(
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
)
op.create_index(op.f("ix_users_password"), "users", ["password"], unique=False)
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
op.create_table(
"authorizationcode",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("code", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("redirect_uri", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("response_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("auth_time", sa.Integer(), nullable=False),
sa.Column("expires_in", sa.Integer(), nullable=False),
sa.Column("code_challenge", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column(
"code_challenge_method", sqlmodel.sql.sqltypes.AutoString(), nullable=True
),
sa.Column("nonce", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_authorizationcode_auth_time"),
"authorizationcode",
["auth_time"],
unique=False,
)
op.create_index(
op.f("ix_authorizationcode_client_id"),
"authorizationcode",
["client_id"],
unique=False,
)
op.create_index(
op.f("ix_authorizationcode_code"), "authorizationcode", ["code"], unique=False
)
op.create_index(
op.f("ix_authorizationcode_code_challenge"),
"authorizationcode",
["code_challenge"],
unique=False,
)
op.create_index(
op.f("ix_authorizationcode_code_challenge_method"),
"authorizationcode",
["code_challenge_method"],
unique=False,
)
op.create_index(
op.f("ix_authorizationcode_expires_in"),
"authorizationcode",
["expires_in"],
unique=False,
)
op.create_index(
op.f("ix_authorizationcode_id"), "authorizationcode", ["id"], unique=True
)
op.create_index(
op.f("ix_authorizationcode_nonce"), "authorizationcode", ["nonce"], unique=False
)
op.create_index(
op.f("ix_authorizationcode_redirect_uri"),
"authorizationcode",
["redirect_uri"],
unique=False,
)
op.create_index(
op.f("ix_authorizationcode_response_type"),
"authorizationcode",
["response_type"],
unique=False,
)
op.create_index(
op.f("ix_authorizationcode_scope"), "authorizationcode", ["scope"], unique=False
)
op.create_index(
op.f("ix_authorizationcode_user_id"),
"authorizationcode",
["user_id"],
unique=False,
)
op.create_table(
"client",
sa.Column("grant_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("response_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("redirect_uris", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_client_client_id"), "client", ["client_id"], unique=False)
op.create_index(
op.f("ix_client_client_secret"), "client", ["client_secret"], unique=False
)
op.create_index(op.f("ix_client_id"), "client", ["id"], unique=True)
op.create_index(op.f("ix_client_scope"), "client", ["scope"], unique=False)
op.create_table(
"token",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("access_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("refresh_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("issued_at", sa.Integer(), nullable=False),
sa.Column("expires_in", sa.Integer(), nullable=False),
sa.Column("refresh_token_expires_in", sa.Integer(), nullable=False),
sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("token_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_token_client_id"), "token", ["client_id"], unique=False)
op.create_index(op.f("ix_token_expires_in"), "token", ["expires_in"], unique=False)
op.create_index(op.f("ix_token_id"), "token", ["id"], unique=True)
op.create_index(op.f("ix_token_issued_at"), "token", ["issued_at"], unique=False)
op.create_index(
op.f("ix_token_refresh_token_expires_in"),
"token",
["refresh_token_expires_in"],
unique=False,
)
op.create_index(op.f("ix_token_revoked"), "token", ["revoked"], unique=False)
op.create_index(op.f("ix_token_scope"), "token", ["scope"], unique=False)
op.create_index(op.f("ix_token_token_type"), "token", ["token_type"], unique=False)
op.create_index(op.f("ix_token_user_id"), "token", ["user_id"], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_token_user_id"), table_name="token")
op.drop_index(op.f("ix_token_token_type"), table_name="token")
op.drop_index(op.f("ix_token_scope"), table_name="token")
op.drop_index(op.f("ix_token_revoked"), table_name="token")
op.drop_index(op.f("ix_token_refresh_token_expires_in"), table_name="token")
op.drop_index(op.f("ix_token_issued_at"), table_name="token")
op.drop_index(op.f("ix_token_id"), table_name="token")
op.drop_index(op.f("ix_token_expires_in"), table_name="token")
op.drop_index(op.f("ix_token_client_id"), table_name="token")
op.drop_table("token")
op.drop_index(op.f("ix_client_scope"), table_name="client")
op.drop_index(op.f("ix_client_id"), table_name="client")
op.drop_index(op.f("ix_client_client_secret"), table_name="client")
op.drop_index(op.f("ix_client_client_id"), table_name="client")
op.drop_table("client")
op.drop_index(op.f("ix_authorizationcode_user_id"), table_name="authorizationcode")
op.drop_index(op.f("ix_authorizationcode_scope"), table_name="authorizationcode")
op.drop_index(
op.f("ix_authorizationcode_response_type"), table_name="authorizationcode"
)
op.drop_index(
op.f("ix_authorizationcode_redirect_uri"), table_name="authorizationcode"
)
op.drop_index(op.f("ix_authorizationcode_nonce"), table_name="authorizationcode")
op.drop_index(op.f("ix_authorizationcode_id"), table_name="authorizationcode")
op.drop_index(
op.f("ix_authorizationcode_expires_in"), table_name="authorizationcode"
)
op.drop_index(
op.f("ix_authorizationcode_code_challenge_method"),
table_name="authorizationcode",
)
op.drop_index(
op.f("ix_authorizationcode_code_challenge"), table_name="authorizationcode"
)
op.drop_index(op.f("ix_authorizationcode_code"), table_name="authorizationcode")
op.drop_index(
op.f("ix_authorizationcode_client_id"), table_name="authorizationcode"
)
op.drop_index(
op.f("ix_authorizationcode_auth_time"), table_name="authorizationcode"
)
op.drop_table("authorizationcode")
op.drop_index(op.f("ix_users_username"), table_name="users")
op.drop_index(op.f("ix_users_password"), table_name="users")
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f("ix_users_is_blocked"), table_name="users")
op.drop_index(op.f("ix_users_is_active"), table_name="users")
op.drop_index(op.f("ix_users_id"), table_name="users")
op.drop_table("users")
# ### end Alembic commands ###

View File

@ -0,0 +1,27 @@
"""tgid
Revision ID: c76c4cbb0b3b
Revises: 07a7ace268a7
Create Date: 2024-01-13 16:12:45.884304
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'c76c4cbb0b3b'
down_revision = '07a7ace268a7'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('users', sa.Column('tg_id', sa.BigInteger(), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('users', 'tg_id')
# ### end Alembic commands ###

26
gen_keys.py Normal file
View File

@ -0,0 +1,26 @@
from pathlib import Path
data_path = Path("data")
data_path.mkdir(exist_ok=True)
private_key_path = data_path / "private_key"
public_key_path = data_path / "public_key"
def gen_keys():
from Crypto.PublicKey import RSA
key = RSA.generate(2048)
private_key = key.export_key().decode("utf-8")
public_key = key.publickey().export_key().decode("utf-8")
if private_key_path.is_file() and public_key_path.is_file():
print("Keys already exist")
return
with open(private_key_path, "w") as f:
f.write(private_key)
with open(public_key_path, "w") as f:
f.write(public_key)
if __name__ == "__main__":
gen_keys()

40
html/login.jinja Normal file
View File

@ -0,0 +1,40 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Title</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.3.2/css/bootstrap.min.css"
integrity="sha512-b2QcS5SsA8tZodcDtGRELiGv5SaKSk1vDHDaQRda0htPYWZ6046lr3kJ5bAAQdpV2mmA/4v0wQF9MyU6/pDIAg=="
crossorigin="anonymous" referrerpolicy="no-referrer"/>
<style>
.main {
text-align: center;
background-color: #fff;
border-radius: 20px;
width: 260px;
height: 20px;
margin: auto;
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
}
</style>
</head>
<body>
<div class="main">
<script async src="https://telegram.org/js/telegram-widget.js?22"
data-telegram-login="{{ username }}"
data-size="large"
data-auth-url="{{ callback_url }}"
data-request-access="write">
</script>
<div>
<a href="https://t.me/{{ username }}?start=login">
<button type="button" class="btn btn-primary">通过 BOT 登录</button>
</a>
</div>
</div>
</body>
</html>

50
main.py Normal file
View File

@ -0,0 +1,50 @@
import asyncio
from signal import signal as signal_fn, SIGINT, SIGTERM, SIGABRT
from src.app import web
from src.bot import bot
from src.logs import logs
async def idle():
task = None
def signal_handler(_, __):
if web.web_server_task:
web.web_server_task.cancel()
task.cancel()
for s in (SIGINT, SIGTERM, SIGABRT):
signal_fn(s, signal_handler)
while True:
task = asyncio.create_task(asyncio.sleep(600))
web.bot_main_task = task
try:
await task
except asyncio.CancelledError:
break
async def main():
logs.info("正在启动 Web Server")
await web.start()
logs.info("正在启动 Bot")
await bot.start()
try:
logs.info("正在运行")
await idle()
finally:
try:
await bot.stop()
except ConnectionError:
pass
if web.web_server:
try:
await web.web_server.shutdown()
except AttributeError:
pass
if __name__ == "__main__":
bot.run(main())

21
pyromod/__init__.py Normal file
View File

@ -0,0 +1,21 @@
"""
pyromod - A monkeypatched add-on for Pyrogram
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
This file is part of pyromod.
pyromod is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
pyromod is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
"""
__version__ = "1.5"

View File

@ -0,0 +1,21 @@
"""
pyromod - A monkeypatcher add-on for Pyrogram
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
This file is part of pyromod.
pyromod is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
pyromod is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
"""
from .listen import Client, MessageHandler, Chat, User

157
pyromod/listen/listen.py Normal file
View File

@ -0,0 +1,157 @@
"""
pyromod - A monkeypatcher add-on for Pyrogram
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
This file is part of pyromod.
pyromod is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
pyromod is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
"""
import asyncio
import functools
import pyrogram
from src.scheduler import add_delete_message_job
from ..utils import patch, patchable
from ..utils.errors import ListenerCanceled, TimeoutConversationError
pyrogram.errors.ListenerCanceled = ListenerCanceled
@patch(pyrogram.client.Client)
class Client:
@patchable
def __init__(self, *args, **kwargs):
self.listening = {}
self.using_mod = True
self.old__init__(*args, **kwargs)
@patchable
async def listen(self, chat_id, filters=None, timeout=None):
if type(chat_id) != int:
chat = await self.get_chat(chat_id)
chat_id = chat.id
future = self.loop.create_future()
future.add_done_callback(functools.partial(self.clear_listener, chat_id))
self.listening.update({chat_id: {"future": future, "filters": filters}})
try:
return await asyncio.wait_for(future, timeout)
except asyncio.exceptions.TimeoutError as e:
raise TimeoutConversationError() from e
@patchable
async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs):
request = await self.send_message(chat_id, text, *args, **kwargs)
response = await self.listen(chat_id, filters, timeout)
response.request = request
return response
@patchable
def clear_listener(self, chat_id, future):
if future == self.listening[chat_id]["future"]:
self.listening.pop(chat_id, None)
@patchable
def cancel_listener(self, chat_id):
listener = self.listening.get(chat_id)
if not listener or listener["future"].done():
return
listener["future"].set_exception(ListenerCanceled())
self.clear_listener(chat_id, listener["future"])
@patchable
def cancel_all_listener(self):
for chat_id in self.listening:
self.cancel_listener(chat_id)
@patch(pyrogram.handlers.message_handler.MessageHandler)
class MessageHandler:
@patchable
def __init__(self, callback: callable, filters=None):
self.user_callback = callback
self.old__init__(self.resolve_listener, filters)
@patchable
async def resolve_listener(self, client, message, *args):
listener = client.listening.get(message.chat.id)
if listener and not listener["future"].done():
listener["future"].set_result(message)
else:
if listener and listener["future"].done():
client.clear_listener(message.chat.id, listener["future"])
await self.user_callback(client, message, *args)
@patchable
async def check(self, client, update):
listener = client.listening.get(update.chat.id)
if listener and not listener["future"].done():
return (
await listener["filters"](client, update)
if callable(listener["filters"])
else True
)
return await self.filters(client, update) if callable(self.filters) else True
@patch(pyrogram.types.user_and_chats.chat.Chat)
class Chat(pyrogram.types.Chat):
@patchable
def listen(self, *args, **kwargs):
return self._client.listen(self.id, *args, **kwargs)
@patchable
def ask(self, *args, **kwargs):
return self._client.ask(self.id, *args, **kwargs)
@patchable
def cancel_listener(self):
return self._client.cancel_listener(self.id)
@patch(pyrogram.types.user_and_chats.user.User)
class User(pyrogram.types.User):
@patchable
def listen(self, *args, **kwargs):
return self._client.listen(self.id, *args, **kwargs)
@patchable
def ask(self, *args, **kwargs):
return self._client.ask(self.id, *args, **kwargs)
@patchable
def cancel_listener(self):
return self._client.cancel_listener(self.id)
@patch(pyrogram.types.messages_and_media.Message)
class Message(pyrogram.types.Message):
@patchable
async def safe_delete(self, revoke: bool = True):
try:
return await self._client.delete_messages(
chat_id=self.chat.id, message_ids=self.id, revoke=revoke
)
except Exception as e: # noqa
return False
@patchable
async def delay_delete(self, delay: int = 60):
add_delete_message_job(self, delay)

21
pyromod/utils/__init__.py Normal file
View File

@ -0,0 +1,21 @@
"""
pyromod - A monkeypatcher add-on for Pyrogram
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
This file is part of pyromod.
pyromod is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
pyromod is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
"""
from .utils import patch, patchable

16
pyromod/utils/errors.py Normal file
View File

@ -0,0 +1,16 @@
class TimeoutConversationError(Exception):
"""
Occurs when the conversation times out.
"""
def __init__(self):
super().__init__("Response read timed out")
class ListenerCanceled(Exception):
"""
Occurs when the listener is canceled.
"""
def __init__(self):
super().__init__("Listener was canceled")

38
pyromod/utils/utils.py Normal file
View File

@ -0,0 +1,38 @@
"""
pyromod - A monkeypatcher add-on for Pyrogram
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
This file is part of pyromod.
pyromod is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
pyromod is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
"""
def patch(obj):
def is_patchable(item):
return getattr(item[1], "patchable", False)
def wrapper(container):
for name, func in filter(is_patchable, container.__dict__.items()):
old = getattr(obj, name, None)
setattr(obj, f"old{name}", old)
setattr(obj, name, func)
return container
return wrapper
def patchable(func):
func.patchable = True
return func

22
requirements.txt Normal file
View File

@ -0,0 +1,22 @@
fastapi==0.109.0
uvicorn==0.25.0
git+https://github.com/aliev/aioauth
sqlmodel==0.0.14
alembic==1.13.1
aiosqlite==0.19.0
PyCryptodome==3.20.0
python-jose[cryptography]==3.3.0
python-multipart==0.0.6
orjson==3.9.10
jinja2==3.1.3
pydantic~=2.5.3
pydantic-settings==2.1.0
SQLAlchemy~=2.0.25
starlette~=0.35.1
pyrogram==2.0.106
tgcrypto==1.2.5
pytz~=2023.3.post1
APScheduler~=3.10.4
coloredlogs~=15.0.1
httpx==0.26.0
asyncmy==0.2.9

0
src/__init__.py Normal file
View File

87
src/app.py Normal file
View File

@ -0,0 +1,87 @@
import asyncio
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from .config import settings
from .events import on_shutdown, on_startup
from .logs import logs
from .oauth2 import endpoints as oauth2_endpoints
from .users import endpoints as users_endpoints
from .users.backends import TokenAuthenticationBackend
class Web:
def __init__(self):
self.app = FastAPI(
title=settings.PROJECT_NAME,
docs_url="/api/openapi",
openapi_url="/api/openapi.json",
default_response_class=ORJSONResponse,
on_startup=on_startup,
on_shutdown=on_shutdown,
)
self.web_server = None
self.web_server_task = None
self.bot_main_task = None
def init_web(self):
# Include API router
self.app.include_router(users_endpoints.router, prefix="/api/users", tags=["users"])
# Define aioauth-fastapi endpoints
self.app.include_router(
oauth2_endpoints.router,
prefix="/oauth2",
tags=["oauth2"],
)
self.app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
self.app.add_middleware(AuthenticationMiddleware, backend=TokenAuthenticationBackend())
async def start(self):
import uvicorn
self.init_web()
self.web_server = uvicorn.Server(
config=uvicorn.Config(
self.app,
host=settings.PROJECT_HOST,
port=settings.PROJECT_PORT,
reload=settings.DEBUG,
)
)
server_config = self.web_server.config
server_config.setup_event_loop()
if not server_config.loaded:
server_config.load()
self.web_server.lifespan = server_config.lifespan_class(server_config)
try:
await self.web_server.startup()
except OSError as e:
if e.errno == 10048:
logs.error("Web Server 端口被占用:%s", e)
logs.error("Web Server 启动失败,正在退出")
raise SystemExit from None
if self.web_server.should_exit:
logs.error("Web Server 启动失败,正在退出")
raise SystemExit from None
logs.info("Web Server 启动成功")
self.web_server_task = asyncio.create_task(self.web_server.main_loop())
async def stop(self):
if self.web_server_task:
self.web_server_task.cancel()
if self.bot_main_task:
self.bot_main_task.cancel()
web = Web()

13
src/bot.py Normal file
View File

@ -0,0 +1,13 @@
import pyromod.listen
from pyrogram import Client
from .config import settings
bot = Client(
"bot",
bot_token=settings.BOT_TOKEN,
api_id=settings.BOT_API_ID,
api_hash=settings.BOT_API_HASH,
plugins={"root": "src.telegram.plugins"},
workdir="data",
)

35
src/config.py Normal file
View File

@ -0,0 +1,35 @@
from typing import List
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
PROJECT_NAME: str = "Telegram OAuth"
PROJECT_URL: str = "http://127.0.0.1:8081"
PROJECT_LOGIN_SUCCESS_URL: str = "http://google.com"
PROJECT_HOST: str = "127.0.0.1"
PROJECT_PORT: int = 8001
DEBUG: bool = True
CONN_URI: str
JWT_PUBLIC_KEY: str
JWT_PRIVATE_KEY: str
ACCESS_TOKEN_EXP: int = 900 # 15 minutes
REFRESH_TOKEN_EXP: int = 86400 # 1 day
CORS_ORIGINS: List[str] = ["*"]
BOT_TOKEN: str
BOT_USERNAME: str
BOT_API_ID: int
BOT_API_HASH: str
BOT_MANAGER_IDS: List[int]
class Config:
env_file = ".env"
case_sensitive = True
settings = Settings()

27
src/events.py Normal file
View File

@ -0,0 +1,27 @@
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlmodel.ext.asyncio.session import AsyncSession
from .config import settings
from .storage import sqlalchemy
async def create_sqlalchemy_connection():
# NOTE: https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html#using-multiple-asyncio-event-loops
engine = create_async_engine(settings.CONN_URI, echo=True, poolclass=NullPool)
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
sqlalchemy.sqlalchemy_session = async_session()
async def close_sqlalchemy_connection():
if sqlalchemy.sqlalchemy_session is not None:
await sqlalchemy.sqlalchemy_session.close()
on_startup = [
create_sqlalchemy_connection,
]
on_shutdown = [
close_sqlalchemy_connection,
]

3
src/html.py Normal file
View File

@ -0,0 +1,3 @@
from starlette.templating import Jinja2Templates
templates = Jinja2Templates(directory="html")

21
src/logs.py Normal file
View File

@ -0,0 +1,21 @@
from logging import getLogger, StreamHandler, basicConfig, INFO, CRITICAL, ERROR
from coloredlogs import ColoredFormatter
logs = getLogger("telegram-oauth")
logging_format = "%(levelname)s [%(asctime)s] [%(name)s] %(message)s"
logging_handler = StreamHandler()
logging_handler.setFormatter(ColoredFormatter(logging_format))
root_logger = getLogger()
root_logger.setLevel(CRITICAL)
root_logger.addHandler(logging_handler)
pyro_logger = getLogger("pyrogram")
pyro_logger.setLevel(INFO)
sql_logger = getLogger("sqlalchemy")
sql_logger.setLevel(CRITICAL)
sql_engine_logger = getLogger("sqlalchemy.engine.Engine")
sql_engine_logger.setLevel(CRITICAL)
aioauth_logger = getLogger("aioauth")
aioauth_logger.setLevel(INFO)
basicConfig(level=ERROR)
logs.setLevel(INFO)

0
src/oauth2/__init__.py Normal file
View File

56
src/oauth2/endpoints.py Normal file
View File

@ -0,0 +1,56 @@
from aioauth.config import Settings
from aioauth.oidc.core.grant_type import AuthorizationCodeGrantType
from aioauth.requests import Request as OAuth2Request
from aioauth.server import AuthorizationServer
from fastapi import APIRouter, Depends, Request
from aioauth_fastapi.utils import to_fastapi_response, to_oauth2_request
from .storage import Storage
from ..config import settings as local_settings
from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
from ..users.crypto import get_pub_key_resp
from ..utils.oauth import to_login_request
router = APIRouter()
settings = Settings(
TOKEN_EXPIRES_IN=local_settings.ACCESS_TOKEN_EXP,
REFRESH_TOKEN_EXPIRES_IN=local_settings.REFRESH_TOKEN_EXP,
INSECURE_TRANSPORT=local_settings.DEBUG,
)
grant_types = {
"authorization_code": AuthorizationCodeGrantType,
}
@router.post("/token")
async def token(
request: Request,
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
):
oauth2_storage = Storage(storage=storage)
authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types)
oauth2_request: OAuth2Request = await to_oauth2_request(request, settings)
oauth2_response = await authorization_server.create_token_response(oauth2_request)
return await to_fastapi_response(oauth2_response)
@router.get("/authorize")
async def authorize(
request: Request,
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
):
if not request.user.is_authenticated:
return await to_login_request(request)
oauth2_storage = Storage(storage=storage)
authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types)
oauth2_request: OAuth2Request = await to_oauth2_request(request, settings)
oauth2_response = await authorization_server.create_authorization_response(
oauth2_request
)
return await to_fastapi_response(oauth2_response)
@router.get("/keys")
async def keys():
return get_pub_key_resp()

50
src/oauth2/models.py Normal file
View File

@ -0,0 +1,50 @@
from typing import TYPE_CHECKING, Optional
from pydantic.types import UUID4
from sqlmodel.main import Field, Relationship
from ..storage.models import BaseTable
if TYPE_CHECKING: # pragma: no cover
from ..users.models import User
class Client(BaseTable, table=True): # type: ignore
client_id: str
client_secret: str
grant_types: str
response_types: str
redirect_uris: str
scope: str
class AuthorizationCode(BaseTable, table=True): # type: ignore
code: str
client_id: str
redirect_uri: str
response_type: str
scope: str
auth_time: int
expires_in: int
code_challenge: Optional[str]
code_challenge_method: Optional[str]
nonce: Optional[str]
user_id: UUID4 = Field(foreign_key="users.id", nullable=False)
user: "User" = Relationship(back_populates="user_authorization_codes")
class Token(BaseTable, table=True): # type: ignore
access_token: str
refresh_token: str
scope: str
issued_at: int
expires_in: int
refresh_token_expires_in: int
client_id: str
token_type: str
revoked: bool
user_id: UUID4 = Field(foreign_key="users.id", nullable=False)
user: "User" = Relationship(back_populates="user_tokens")

289
src/oauth2/storage.py Normal file
View File

@ -0,0 +1,289 @@
from datetime import datetime, timezone
from typing import Optional
from aioauth.models import AuthorizationCode, Client, Token
from aioauth.requests import Request
from aioauth.storage import BaseStorage
from aioauth.types import CodeChallengeMethod, ResponseType, TokenType
from aioauth.utils import enforce_list
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import delete
from .models import AuthorizationCode as AuthorizationCodeDB
from .models import Client as ClientDB
from .models import Token as TokenDB
from ..config import settings
from ..storage.sqlalchemy import SQLAlchemyStorage
from ..users.crypto import encode_jwt, get_jwt
from ..users.models import User
class Storage(BaseStorage):
def __init__(self, storage: SQLAlchemyStorage):
self.storage = storage
async def get_user(self, request: Request):
user: Optional[User] = None
if request.query.response_type == "token":
# If ResponseType is token get the user from current session
user = request.user
if request.post.grant_type == "authorization_code":
# If GrantType is authorization code get user from DB by code
q_results = await self.storage.select(
select(AuthorizationCodeDB).where(
AuthorizationCodeDB.code == request.post.code
)
)
authorization_code: Optional[AuthorizationCodeDB]
authorization_code = q_results.scalars().one_or_none()
if not authorization_code:
return
q_results = await self.storage.select(
select(User).where(User.id == authorization_code.user_id)
)
user = q_results.scalars().one_or_none()
if request.post.grant_type == "refresh_token":
# Get user from token
q_results = await self.storage.select(
select(TokenDB)
.where(TokenDB.refresh_token == request.post.refresh_token)
.options(selectinload(TokenDB.user))
)
token: Optional[TokenDB]
token = q_results.scalars().one_or_none()
if not token:
return
user = token.user
return user
async def create_token(
self,
request: Request,
client_id: str,
scope: str,
access_token: str,
refresh_token: str,
) -> Token:
"""
Create token and store it in storage.
"""
user = await self.get_user(request)
_access_token, _refresh_token = get_jwt(user)
token = Token(
access_token=_access_token,
client_id=client_id,
expires_in=300,
issued_at=int(datetime.now(tz=timezone.utc).timestamp()),
refresh_token=_refresh_token,
refresh_token_expires_in=900,
revoked=False,
scope=scope,
token_type="Bearer",
user=user,
)
token_record = TokenDB(
access_token=token.access_token,
refresh_token=token.refresh_token,
scope=token.scope,
issued_at=token.issued_at,
expires_in=token.expires_in,
refresh_token_expires_in=token.refresh_token_expires_in,
client_id=token.client_id,
token_type=token.token_type,
revoked=token.revoked,
user_id=user.id,
)
await self.storage.add(token_record)
return token
async def revoke_token(
self,
request: Request,
token_type: Optional[TokenType] = "refresh_token",
access_token: Optional[str] = None,
refresh_token: Optional[str] = None,
) -> None:
"""
Remove refresh_token from whitelist.
"""
q_results = await self.storage.select(
select(TokenDB).where(TokenDB.refresh_token == refresh_token)
)
token_record: Optional[TokenDB]
token_record = q_results.scalars().one_or_none()
if token_record:
token_record.revoked = True
await self.storage.add(token_record)
async def get_token(
self,
request: Request,
client_id: str,
token_type: Optional[str] = "refresh_token",
access_token: Optional[str] = None,
refresh_token: Optional[str] = None,
) -> Optional[Token]:
if token_type == "refresh_token":
q = select(TokenDB).where(TokenDB.refresh_token == refresh_token)
else:
q = select(TokenDB).where(TokenDB.access_token == access_token)
q_results = await self.storage.select(
q.where(TokenDB.revoked == False).options( # noqa
selectinload(TokenDB.user)
)
)
token_record: Optional[TokenDB]
token_record = q_results.scalars().one_or_none()
if token_record:
return Token(
access_token=token_record.access_token,
refresh_token=token_record.refresh_token,
scope=token_record.scope,
issued_at=token_record.issued_at,
expires_in=token_record.expires_in,
refresh_token_expires_in=token_record.refresh_token_expires_in,
client_id=client_id,
)
async def create_authorization_code(
self,
request: Request,
client_id: str,
scope: str,
response_type: ResponseType,
redirect_uri: str,
code_challenge_method: Optional[CodeChallengeMethod],
code_challenge: Optional[str],
code: str,
**kwargs,
) -> AuthorizationCode:
authorization_code = AuthorizationCode(
auth_time=int(datetime.now(tz=timezone.utc).timestamp()),
client_id=client_id,
code=code,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
expires_in=300,
redirect_uri=redirect_uri,
response_type=response_type,
scope=scope,
user=request.user,
)
authorization_code_record = AuthorizationCodeDB(
code=authorization_code.code,
client_id=authorization_code.client_id,
redirect_uri=authorization_code.redirect_uri,
response_type=authorization_code.response_type,
scope=authorization_code.scope,
auth_time=authorization_code.auth_time,
expires_in=authorization_code.expires_in,
code_challenge_method=authorization_code.code_challenge_method,
code_challenge=authorization_code.code_challenge,
nonce=authorization_code.nonce,
user_id=request.user.id,
)
await self.storage.add(authorization_code_record)
return authorization_code
async def get_client(
self, request: Request, client_id: str, client_secret: Optional[str] = None
) -> Optional[Client]:
q_results = await self.storage.select(
select(ClientDB).where(ClientDB.client_id == client_id)
)
client_record: Optional[ClientDB]
client_record = q_results.scalars().one_or_none()
if not client_record:
return None
return Client(
client_id=client_record.client_id,
client_secret=client_record.client_secret,
grant_types=[client_record.grant_types],
response_types=[client_record.response_types],
redirect_uris=[client_record.redirect_uris],
scope=client_record.scope,
)
async def get_authorization_code(
self, request: Request, client_id: str, code: str
) -> Optional[AuthorizationCode]:
q_results = await self.storage.select(
select(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code)
)
authorization_code_record: Optional[AuthorizationCode]
authorization_code_record = q_results.scalars().one_or_none()
if not authorization_code_record:
return None
return AuthorizationCode(
code=authorization_code_record.code,
client_id=authorization_code_record.client_id,
redirect_uri=authorization_code_record.redirect_uri,
response_type=authorization_code_record.response_type,
scope=authorization_code_record.scope,
auth_time=authorization_code_record.auth_time,
expires_in=authorization_code_record.expires_in,
code_challenge=authorization_code_record.code_challenge,
code_challenge_method=authorization_code_record.code_challenge_method,
nonce=authorization_code_record.nonce,
)
async def delete_authorization_code(
self, request: Request, client_id: str, code: str
) -> None:
await self.storage.delete(
delete(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code)
)
async def get_id_token(
self,
request: Request,
client_id: str,
scope: str,
response_type: ResponseType,
redirect_uri: str,
**kwargs,
) -> str:
scopes = enforce_list(scope)
user = await self.get_user(request)
user_data = {}
if "email" in scopes:
user_data["email"] = user.username
user_data["username"] = user.username
return encode_jwt(
expires_delta=settings.ACCESS_TOKEN_EXP,
sub=str(user.id),
additional_claims=user_data,
)

33
src/scheduler.py Normal file
View File

@ -0,0 +1,33 @@
import contextlib
import datetime
from typing import TYPE_CHECKING
import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler
if TYPE_CHECKING:
from src.telegram.enums import Message
scheduler = AsyncIOScheduler(timezone="Asia/ShangHai")
if not scheduler.running:
scheduler.start()
async def delete_message(message: "Message") -> bool:
with contextlib.suppress(Exception):
await message.delete()
return True
return False
def add_delete_message_job(message: "Message", delete_seconds: int = 60):
scheduler.add_job(
delete_message,
"date",
id=f"{message.chat.id}|{message.id}|delete_message",
name=f"{message.chat.id}|{message.id}|delete_message",
args=[message],
run_date=datetime.datetime.now(pytz.timezone("Asia/Shanghai"))
+ datetime.timedelta(seconds=delete_seconds),
replace_existing=True,
)

0
src/storage/__init__.py Normal file
View File

14
src/storage/models.py Normal file
View File

@ -0,0 +1,14 @@
import uuid
from pydantic.types import UUID4
from sqlmodel import Field, SQLModel
class BaseTable(SQLModel):
id: UUID4 = Field(
primary_key=True,
default_factory=uuid.uuid4,
nullable=False,
index=True,
sa_column_kwargs={"unique": True},
)

70
src/storage/sqlalchemy.py Normal file
View File

@ -0,0 +1,70 @@
from typing import Optional
from sqlalchemy.engine.result import Result
from sqlalchemy.sql.expression import Delete, Update
from sqlalchemy.sql.selectable import Select
from sqlmodel.ext.asyncio.session import AsyncSession
class SQLAlchemyTransaction:
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def __aenter__(self) -> "SQLAlchemyTransaction":
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
if exc_type is None:
await self.commit()
else:
await self.rollback()
await self.close()
async def rollback(self):
await self.session.rollback()
async def commit(self):
await self.session.commit()
async def close(self):
await self.session.close()
class SQLAlchemyStorage:
def __init__(
self, session: AsyncSession, transaction: SQLAlchemyTransaction
) -> None:
self.session = session
self.transaction = transaction
async def select(self, q: Select) -> Result:
async with self.transaction:
return await self.session.execute(q)
async def add(self, model) -> None:
async with self.transaction:
self.session.add(model)
async def delete(self, q: Delete) -> None:
async with self.transaction:
await self.session.execute(q)
async def update(self, q: Update):
async with self.transaction:
await self.session.execute(q)
sqlalchemy_session: Optional[AsyncSession] = None
def get_sqlalchemy_storage() -> SQLAlchemyStorage:
"""Get SQLAlchemy storage instance.
Returns:
SQLAlchemyStorage: SQLAlchemy storage instance
"""
sqllachemy_trancation = SQLAlchemyTransaction(session=sqlalchemy_session)
return SQLAlchemyStorage(
session=sqlalchemy_session, transaction=sqllachemy_trancation
)

0
src/telegram/__init__.py Normal file
View File

25
src/telegram/enums.py Normal file
View File

@ -0,0 +1,25 @@
from typing import Optional
from pyrogram import Client as PyroClient
from pyrogram.types import Message as PyroMessage
class Client(PyroClient): # noqa
async def listen(self, chat_id, filters=None, timeout=None) -> Optional["Message"]:
return
async def ask(
self, chat_id, text, filters=None, timeout=None, *args, **kwargs
) -> Optional["Message"]:
return
def cancel_listener(self, chat_id):
"""Cancel the conversation with the given chat_id."""
class Message(PyroMessage): # noqa
async def delay_delete(self, delete_seconds: int = 60) -> Optional[bool]:
return
async def safe_delete(self, revoke: bool = True) -> None:
return

15
src/telegram/filters.py Normal file
View File

@ -0,0 +1,15 @@
from typing import TYPE_CHECKING
from pyrogram.filters import create
from src.config import settings
if TYPE_CHECKING:
from .enums import Message
async def admin_filter(_, __, m: "Message"):
return bool(m.from_user and m.from_user.id in settings.BOT_MANAGER_IDS)
admin = create(admin_filter)

8
src/telegram/message.py Normal file
View File

@ -0,0 +1,8 @@
import re
NO_ACCOUNT_MSG = """UID `%s` 还没有注册账号,请联系管理员注册账号。"""
ACCOUNT_MSG = """UID: `%s`\n邮箱: `%s`"""
REG_MSG = """请发送需要使用的邮箱"""
MAIL_REGEX = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
LOGIN_MSG = """请点击下面的按钮登录:"""
LOGIN_BUTTON = """跳转登录"""

View File

View File

@ -0,0 +1,24 @@
from pyrogram import filters
from src.bot import bot
from src.config import settings
from src.telegram.enums import Client, Message
from src.telegram.message import ACCOUNT_MSG, NO_ACCOUNT_MSG
from src.users.crud import get_user_crud
async def account(message: Message, uid: int):
crud = get_user_crud()
user = await crud.get_by_tg_id(uid)
if user:
await message.reply(ACCOUNT_MSG % (user.tg_id, user.username), quote=True)
else:
await message.reply(NO_ACCOUNT_MSG % uid, quote=True)
@bot.on_message(filters=filters.private & filters.command("account"))
async def get_account(_: Client, message: Message):
uid = message.from_user.id
if uid in settings.BOT_MANAGER_IDS and len(message.command) >= 2 and message.command[1].isnumeric():
uid = int(message.command[1])
await account(message, uid)

View File

@ -0,0 +1,44 @@
from pyrogram import filters
from pyromod.utils.errors import TimeoutConversationError
from src.bot import bot
from src.logs import logs
from src.telegram.enums import Client, Message
from src.telegram.filters import admin
from src.telegram.message import REG_MSG, MAIL_REGEX
from src.users.crud import get_user_crud
async def reg(client: Client, from_id: int, uid: int):
msg_ = await client.send_message(from_id, REG_MSG)
try:
msg = await client.listen(from_id, filters=filters.text, timeout=60)
except TimeoutConversationError:
await msg_.edit("响应超时,请重试")
return
if msg.text and MAIL_REGEX.match(msg.text):
crud = get_user_crud()
try:
user = await crud.get_by_tg_id(uid)
if user:
await crud.update(user, username=msg.text)
else:
await crud.create(
username=msg.text,
password="1",
tg_id=uid,
)
except Exception as e:
logs.exception("注册失败", exc_info=e)
await msg.reply_text("注册失败")
await msg.reply_text("注册成功")
else:
await msg.reply_text("邮箱格式错误")
@bot.on_message(filters=filters.private & filters.command("edit") & admin)
async def edit_account(client: Client, message: Message):
uid = from_id = message.from_user.id
if len(message.command) >= 2 and message.command[1].isnumeric():
uid = int(message.command[1])
await reg(client, from_id, uid)

View File

@ -0,0 +1,41 @@
from httpx import URL
from pyrogram import filters, Client
from pyrogram.types import InlineKeyboardMarkup, InlineKeyboardButton
from src.bot import bot
from src.config import settings
from src.telegram.enums import Message
from src.telegram.message import NO_ACCOUNT_MSG, LOGIN_MSG, LOGIN_BUTTON
from src.users.crud import get_user_crud
from src.utils.telegram import encode_telegram_auth_data
async def login(message: Message):
uid = message.from_user.id
crud = get_user_crud()
user = await crud.get_by_tg_id(uid)
if not user:
await message.reply(NO_ACCOUNT_MSG % uid, quote=True)
return
token = await encode_telegram_auth_data(uid)
url = settings.PROJECT_URL + "/api/users/auth"
url = URL(url).copy_add_param("jwt", token)
url = str(url)
await message.reply(
LOGIN_MSG,
quote=True,
reply_markup=InlineKeyboardMarkup(
[[InlineKeyboardButton(LOGIN_BUTTON, url=url)]]
),
)
@bot.on_message(filters=filters.private & filters.command("start"))
async def start(client: Client, message: Message):
if message.command and len(message.command) >= 2:
action = message.command[1]
if action == "login":
await login(message)
return
me = await client.get_me()
await message.reply(f"Hello, I'm {me.first_name}. ")

0
src/users/__init__.py Normal file
View File

32
src/users/backends.py Normal file
View File

@ -0,0 +1,32 @@
from fastapi.security.utils import get_authorization_scheme_param
from starlette.authentication import AuthCredentials, AuthenticationBackend
from .crypto import authenticate, read_rsa_key_from_env
from .models import User, UserAnonymous
from ..config import settings
class TokenAuthenticationBackend(AuthenticationBackend):
async def authenticate(self, request):
authorization: str = request.headers.get("Authorization")
_, bearer_token = get_authorization_scheme_param(authorization)
token: str = request.cookies.get("access_token") or bearer_token
if not token:
return AuthCredentials(), UserAnonymous()
key = read_rsa_key_from_env(settings.JWT_PUBLIC_KEY)
is_authenticated, decoded_token = authenticate(token=token, key=key)
if is_authenticated:
return AuthCredentials(), User(
id=decoded_token["sub"],
is_superuser=decoded_token["is_superuser"],
is_blocked=decoded_token["is_blocked"],
is_active=decoded_token["is_active"],
username=decoded_token["username"],
)
return AuthCredentials(), UserAnonymous()

40
src/users/crud.py Normal file
View File

@ -0,0 +1,40 @@
from typing import Optional
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import Update
from .models import User
from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
class SQLAlchemyCRUD:
def __init__(self, storage: SQLAlchemyStorage):
self.storage = storage
async def get_by_tg_id(self, tg_id: int) -> Optional[User]:
q_results = await self.storage.select(
select(User)
.options(
# for relationship loading, eager loading should be applied.
selectinload(User.user_tokens)
)
.where(User.tg_id == tg_id)
)
return q_results.scalars().one_or_none()
async def create(self, **kwargs) -> None:
user = User(**kwargs)
await self.storage.add(user)
async def update(self, user: User, **kwargs) -> None:
await self.storage.update(
Update(User).where(User.id == user.id).values(**kwargs)
)
def get_user_crud(storage: SQLAlchemyStorage = None) -> SQLAlchemyCRUD:
if storage is None:
storage = get_sqlalchemy_storage()
return SQLAlchemyCRUD(storage=storage)

166
src/users/crypto.py Normal file
View File

@ -0,0 +1,166 @@
import base64
import pathlib
import re
import string
import uuid
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
from Crypto.PublicKey import RSA
from jose import constants, jwt
from jose.exceptions import JWTError
from ..config import settings
RANDOM_STRING_CHARS = string.ascii_lowercase + string.ascii_uppercase + string.digits
KEYS = {}
def reformat_rsa_key(rsa_key: str) -> str:
"""Reformat an RSA PEM key without newlines to one with correct newline characters
@param rsa_key: the PEM RSA key lacking newline characters
@return: the reformatted PEM RSA key with appropriate newline characters
"""
# split headers from the body
split_rsa_key = re.split(r"(-+)", rsa_key)
# add newlines between headers and body
split_rsa_key.insert(4, "\n")
split_rsa_key.insert(6, "\n")
reformatted_rsa_key = "".join(split_rsa_key)
# reformat body
return RSA.importKey(reformatted_rsa_key).exportKey().decode("utf-8")
def read_rsa_key_from_env(file_path: str) -> str:
if file_path in KEYS:
return KEYS[file_path]
path = pathlib.Path(file_path)
# path to rsa key file
if path.is_file():
with open(file_path, "rb") as key_file:
jwt_private_key = RSA.importKey(key_file.read()).exportKey()
k = jwt_private_key.decode("utf-8")
KEYS[file_path] = k
return k
# rsa key without newlines
if "\n" not in file_path:
k = reformat_rsa_key(file_path)
KEYS[file_path] = k
return k
return file_path
def get_n(rsa: RSA):
bytes_data = rsa.n.to_bytes((rsa.n.bit_length() + 7) // 8, 'big')
return base64.urlsafe_b64encode(bytes_data).decode('utf-8')
def get_pub_key_resp():
pub_key = RSA.importKey(read_rsa_key_from_env(settings.JWT_PUBLIC_KEY))
return {
"keys": [
{
"n": get_n(pub_key),
"kty": "RSA",
"alg": "RS256",
"kid": "sig",
"e": "AQAB",
"use": "sig"
}
]
}
def encode_jwt(
expires_delta,
sub,
secret=None,
additional_claims=None,
algorithm=constants.ALGORITHMS.RS256,
):
if additional_claims is None:
additional_claims = {}
if secret is None:
secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY)
now = datetime.now(timezone.utc)
claims = {
"iat": now,
"jti": str(uuid.uuid4()),
"nbf": now,
"sub": sub,
"exp": now + timedelta(seconds=expires_delta),
**additional_claims,
}
return jwt.encode(
claims,
secret,
algorithm,
)
def decode_jwt(
encoded_token,
secret=None,
algorithms=None,
):
if algorithms is None:
algorithms = constants.ALGORITHMS.RS256
if secret is None:
secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY)
return jwt.decode(
encoded_token,
secret,
algorithms=algorithms,
)
def get_jwt(user):
access_token = encode_jwt(
sub=str(user.id),
expires_delta=settings.ACCESS_TOKEN_EXP,
additional_claims={
"token_type": "access",
"is_blocked": user.is_blocked,
"is_superuser": user.is_superuser,
"username": user.username,
"is_active": user.is_active,
},
)
refresh_token = encode_jwt(
sub=str(user.id),
expires_delta=settings.REFRESH_TOKEN_EXP,
additional_claims={
"token_type": "refresh",
"is_blocked": user.is_blocked,
"is_superuser": user.is_superuser,
"username": user.username,
"is_active": user.is_active,
},
)
return access_token, refresh_token
def authenticate(
*,
token: str,
key: str,
) -> Tuple[bool, Dict]:
"""Authenticate user by token"""
try:
token_header = jwt.get_unverified_header(token)
decoded_token = jwt.decode(token, key, algorithms=token_header.get("alg"))
except JWTError:
return False, {}
else:
return True, decoded_token

79
src/users/endpoints.py Normal file
View File

@ -0,0 +1,79 @@
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException
from jose import JWTError
from starlette.requests import Request
from .crud import SQLAlchemyCRUD
from .crypto import get_jwt
from ..config import settings
from ..html import templates
from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
from ..utils.oauth import back_auth_request
from ..utils.redirect import RedirectResponseBuilder
from ..utils.telegram import decode_telegram_auth_data, verify_telegram_auth_data
router = APIRouter()
@router.get("/login", name="users:login:get")
async def user_login_get(request: Request):
if request.user.is_authenticated:
if resp := await back_auth_request(request):
return resp
return RedirectResponseBuilder().build(settings.PROJECT_LOGIN_SUCCESS_URL)
url = request.url
callback_url = str(url).replace("/login", "/callback")
return templates.TemplateResponse(
"login.jinja",
{"request": request, "callback_url": callback_url, "username": settings.BOT_USERNAME}
)
async def auth(
tg_id: int,
request: Request,
storage: SQLAlchemyStorage,
):
if tg_id is None:
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
crud = SQLAlchemyCRUD(storage=storage)
user = await crud.get_by_tg_id(tg_id=tg_id)
if user is None:
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
access_token, refresh_token = get_jwt(user)
# NOTE: Setting expire causes an exception for requests library:
# https://github.com/psf/requests/issues/6004
if resp := await back_auth_request(request, access_token, refresh_token):
return resp
resp = RedirectResponseBuilder()
resp.set_cookie(
key="access_token", value=access_token, max_age=settings.ACCESS_TOKEN_EXP
)
resp.set_cookie(
key="refresh_token", value=refresh_token, max_age=settings.REFRESH_TOKEN_EXP
)
return resp.build(settings.PROJECT_LOGIN_SUCCESS_URL)
@router.get("/callback", name="users:login")
async def user_login(
request: Request,
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
):
tg_id = await verify_telegram_auth_data(request.query_params)
return await auth(tg_id, request, storage)
@router.get("/auth", name="users:auth")
async def user_auth(
request: Request,
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
):
try:
tg_id = await decode_telegram_auth_data(request.query_params)
except JWTError:
tg_id = None
return await auth(tg_id, request, storage)

37
src/users/models.py Normal file
View File

@ -0,0 +1,37 @@
from typing import TYPE_CHECKING, List, Optional
from pydantic import BaseModel
from sqlalchemy import Column, BigInteger
from sqlmodel.main import Field, Relationship
from ..storage.models import BaseTable
if TYPE_CHECKING: # pragma: no cover
from ..oauth2.models import AuthorizationCode, Token
class UserAnonymous(BaseModel):
@property
def is_authenticated(self) -> bool:
return False
class User(BaseTable, table=True): # type: ignore
__tablename__ = "users"
is_superuser: bool = False
is_blocked: bool = False
is_active: bool = False
username: str = Field(nullable=False, sa_column_kwargs={"unique": True}, index=True)
password: Optional[str] = None
tg_id: int = Field(sa_column=Column(BigInteger(), nullable=False))
user_authorization_codes: List["AuthorizationCode"] = Relationship(
back_populates="user"
)
user_tokens: List["Token"] = Relationship(back_populates="user")
@property
def is_authenticated(self) -> bool:
return True

0
src/utils/__init__.py Normal file
View File

38
src/utils/oauth.py Normal file
View File

@ -0,0 +1,38 @@
from typing import Optional
from starlette.requests import Request
from starlette.responses import RedirectResponse
from src.config import settings
from src.users.crypto import encode_jwt, decode_jwt
from src.utils.redirect import RedirectResponseBuilder
async def to_login_request(request: Request) -> RedirectResponse:
query_params = dict(request.query_params)
params = ""
for key, value in query_params.items():
params += f"{key}={value}&"
params = params[:-1]
jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, "", additional_claims={"params": params})
resp = RedirectResponseBuilder()
resp.set_cookie("SEND", jwt, max_age=settings.ACCESS_TOKEN_EXP)
return resp.build("/api/users/login")
async def back_auth_request(
request: Request,
access_token: str = None,
refresh_token: str = None,
) -> Optional[RedirectResponse]:
cookie = request.cookies.get("SEND")
if cookie is None:
return None
params = decode_jwt(cookie)["params"]
resp = RedirectResponseBuilder()
if access_token:
resp.set_cookie("access_token", access_token, max_age=settings.ACCESS_TOKEN_EXP)
if refresh_token:
resp.set_cookie("refresh_token", refresh_token, max_age=settings.ACCESS_TOKEN_EXP)
resp.delete_cookie("SEND")
return resp.build(f"/oauth2/authorize?{params}", status_code=303)

78
src/utils/redirect.py Normal file
View File

@ -0,0 +1,78 @@
import http.cookies
import typing
from datetime import datetime
from email.utils import format_datetime
from starlette.datastructures import MutableHeaders
from starlette.responses import RedirectResponse
class RedirectResponseBuilder:
def __init__(self):
self.raw_headers = []
def set_cookie(
self,
key: str,
value: str = "",
max_age: typing.Optional[int] = None,
expires: typing.Optional[typing.Union[datetime, str, int]] = None,
path: str = "/",
domain: typing.Optional[str] = None,
secure: bool = False,
httponly: bool = False,
samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax",
) -> None:
cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
cookie[key] = value
if max_age is not None:
cookie[key]["max-age"] = max_age
if expires is not None:
if isinstance(expires, datetime):
cookie[key]["expires"] = format_datetime(expires, usegmt=True)
else:
cookie[key]["expires"] = expires
if path is not None:
cookie[key]["path"] = path
if domain is not None:
cookie[key]["domain"] = domain
if secure:
cookie[key]["secure"] = True
if httponly:
cookie[key]["httponly"] = True
if samesite is not None:
assert samesite.lower() in [
"strict",
"lax",
"none",
], "samesite must be either 'strict', 'lax' or 'none'"
cookie[key]["samesite"] = samesite
cookie_val = cookie.output(header="").strip()
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
def delete_cookie(
self,
key: str,
path: str = "/",
domain: typing.Optional[str] = None,
secure: bool = False,
httponly: bool = False,
samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax",
) -> None:
self.set_cookie(
key,
max_age=0,
expires=0,
path=path,
domain=domain,
secure=secure,
httponly=httponly,
samesite=samesite,
)
@property
def headers(self) -> MutableHeaders:
return MutableHeaders(raw=self.raw_headers)
def build(self, url: str, status_code: int = 307):
return RedirectResponse(url, headers=self.headers, status_code=status_code)

45
src/utils/telegram.py Normal file
View File

@ -0,0 +1,45 @@
import hashlib
import hmac
from datetime import datetime, timezone
from typing import Optional
from starlette.datastructures import QueryParams
from src.config import settings
from src.users.crypto import encode_jwt, decode_jwt
async def verify_telegram_auth_data(params: QueryParams) -> Optional[int]:
data = list(params.items())
hash_str = ""
text_list = []
for key, value in data:
if key == "hash":
hash_str = value
else:
text_list.append(f"{key}={value}")
check_string = "\n".join(sorted(text_list))
secret_key = hashlib.sha256(str.encode(settings.BOT_TOKEN)).digest()
hmac_hash = hmac.new(secret_key, str.encode(check_string), hashlib.sha256).hexdigest()
return int(params.get("id")) if hmac_hash == hash_str else None
async def encode_telegram_auth_data(uid: int) -> str:
jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, str(uid))
return jwt
async def decode_telegram_auth_data(params: QueryParams) -> Optional[int]:
jwt = params.get("jwt")
if not jwt:
return None
if not jwt:
return None
data = decode_jwt(jwt)
now = datetime.now(timezone.utc)
uid, exp = data["sub"], data["exp"]
if exp < now.timestamp():
return None
return int(uid)